首页 > NumPy

NumPy广播机制剖析

NumPy 的 Universal functions 中要求输入的数组 shape 是一致的,当数组的 shape 不相等时,则会使用广播机制。不过,调整数组使得 shape 一样,需要满足一定的规则,否则将出错。这些规则可归纳为以下 4 条。

1) 让所有输入数组都向其中 shape 最长的数组看齐,不足的部分则通过在前面加 1 补齐,如:

a:2×3×2
b:3×2

则 b 向 a 看齐,在 b 的前面加 1,变为 1×3×2。

2) 输出数组的 shape 是输入数组 shape 的各个轴上的最大值。

3) 如果输入数组的某个轴和输出数组的对应轴的长度相同或者某个轴的长度为 1 时,这个数组能被用来计算,否则出错。

4) 当输入数组的某个轴的长度为 1 时,沿着此轴运算时都用(或复制)此轴上的第一组值。

广播在整个 NumPy 中用于决定如何处理形状迥异的数组,涉及的算术运算包括(+,-,*,/…)。这些规则说得很严谨,但不直观,下面我们结合图形与代码来进一步说明。

目的:计算A+B,其中 A 为 4×1 矩阵,B 为一维向量 (3,)。

要相加,需要做如下处理:

NumPy 广播规则示意图
图1:NumPy 广播规则示意图

请看下面的代码实现:
import numpy as np
A = np.arange(0, 40,10).reshape(4, 1)
B = np.arange(0, 3)
print("A矩阵的形状:{},B矩阵的形状:{}".format(A.shape,B.shape))
C=A+B
print("C矩阵的形状:{}".format(C.shape))
print(C)
运行结果:

A矩阵的形状:(4, 1),B矩阵的形状:(3,)
C矩阵的形状:(4, 3)
[[ 0  1  2]
  [10 11 12]
  [20 21 22]
  [30 31 32]]

所有教程

优秀文章