跳至内容
什么是广播 Broadcasting

什么是广播 Broadcasting

广播broadcasting

好问题 👍,**广播(broadcasting)**是 PyTorch(和 NumPy)里的一个重要机制,用来让不同形状的张量在运算时“自动扩展”成兼容的形状,而不用你手动复制数据。


1. 广播的直观理解

你可以把它理解成 把小张量复制成大张量,让它们在维度上对齐后再计算。 但是实际并不会真的复制数据(否则内存太浪费),只是计算时“假装”扩展了。


2. 广播的规则

假设有两个张量 A 和 B,它们形状分别是:

A.shape = (m, n)
B.shape = (1, n)

广播时,PyTorch 会比较它们的维度:

  1. 从最后一个维度开始对齐。
  2. 如果两个维度相同 → 没问题。
  3. 如果一个是 1,另一个是 k → 自动扩展成 k
  4. 如果都不一样且都不是 1 → 报错。

比如:

(4, 3)  
(1, 3)  →  (4, 3)   # 复制第一维

3. 回到你的例子

1
torch.mean(self.root_states[:, 2].unsqueeze(1) - self.measured_heights, dim=1)
  • self.root_states[:, 2].unsqueeze(1) → shape 是 (N, 1),表示每个机器人当前的高度。
  • self.measured_heights → shape 是 (N, M),表示机器人脚下不同探测点的高度。

相减时:

(N, 1)   -   (N, M)

广播规则会把 (N, 1) 扩展成 (N, M),相当于每个高度值都复制到 M 列。 这样就能做逐元素相减,结果也是 (N, M)

最后:

1
torch.mean(..., dim=1)

对第 1 维(每行的 M 个探测点)取平均,得到 (N,),就是每个机器人的平均高度差。


4. 用列表类比

假设:

1
2
3
4
a = [[2],
     [3]]
b = [[1, 1, 1],
     [4, 5, 6]]

相减时,a 会被广播成:

[[2, 2, 2],
 [3, 3, 3]]

然后:

[[2-1, 2-1, 2-1],
 [3-4, 3-5, 3-6]]
= [[1, 1, 1],
   [-1, -2, -3]]

再按行取平均 → [1, -2]


要不要我给你画一张图,把 (N,1) → (N,M) 的广播过程可视化,这样你更直观地理解?