什么是广播 Broadcasting
广播broadcasting
好问题 👍,**广播(broadcasting)**是 PyTorch(和 NumPy)里的一个重要机制,用来让不同形状的张量在运算时“自动扩展”成兼容的形状,而不用你手动复制数据。
1. 广播的直观理解
你可以把它理解成 把小张量复制成大张量,让它们在维度上对齐后再计算。 但是实际并不会真的复制数据(否则内存太浪费),只是计算时“假装”扩展了。
2. 广播的规则
假设有两个张量 A 和 B,它们形状分别是:
A.shape = (m, n)
B.shape = (1, n)广播时,PyTorch 会比较它们的维度:
- 从最后一个维度开始对齐。
- 如果两个维度相同 → 没问题。
- 如果一个是
1,另一个是k→ 自动扩展成k。 - 如果都不一样且都不是
1→ 报错。
比如:
(4, 3)
(1, 3) → (4, 3) # 复制第一维3. 回到你的例子
| |
self.root_states[:, 2].unsqueeze(1)→ shape 是(N, 1),表示每个机器人当前的高度。self.measured_heights→ shape 是(N, M),表示机器人脚下不同探测点的高度。
相减时:
(N, 1) - (N, M)广播规则会把 (N, 1) 扩展成 (N, M),相当于每个高度值都复制到 M 列。
这样就能做逐元素相减,结果也是 (N, M)。
最后:
| |
对第 1 维(每行的 M 个探测点)取平均,得到 (N,),就是每个机器人的平均高度差。
4. 用列表类比
假设:
| |
相减时,a 会被广播成:
[[2, 2, 2],
[3, 3, 3]]然后:
[[2-1, 2-1, 2-1],
[3-4, 3-5, 3-6]]
= [[1, 1, 1],
[-1, -2, -3]]再按行取平均 → [1, -2]
要不要我给你画一张图,把 (N,1) → (N,M) 的广播过程可视化,这样你更直观地理解?