You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何使用单个Numpy操作实现按索引数组提取三维矩阵的对应元素

如何使用单个Numpy操作实现按索引数组提取三维矩阵的对应元素

嗨,这个需求用Numpy的高级索引就能一步搞定,比列表推导式更高效,完全不用循环~

先回顾下你的场景:你有一个形状为(4, 3, 2)的3D数组entries,还有一个长度为4的索引数组indices,每个元素对应要提取的第二个维度的行索引,最终想得到形状为(4, 2)的结果数组。

核心解决方案

把你原来的列表推导式替换成下面这行代码就行:

import numpy as np

entry = np.array([[1, 2],[3,4],[5, 6]])
entries = np.stack([entry, entry, entry, entry])
indices = np.array([2, 1, 0, 1])

# 单个Numpy操作实现提取
r = entries[np.arange(entries.shape[0]), indices]

# 查看结果
print(r)
# 输出:
# [[5 6]
#  [3 4]
#  [1 2]
#  [3 4]]

为什么这个方法有效?

这里的关键是成对索引匹配

  • np.arange(entries.shape[0])生成了第一个维度的索引数组[0,1,2,3],对应每个要处理的2D子数组
  • 把它和indices数组[2,1,0,1]组合,Numpy会自动将这两个一维数组配对成(0,2), (1,1), (2,0), (3,1)这样的索引对,然后从entries中提取每个索引对对应的整行(第三个维度的所有元素),最终直接得到形状为(4,2)的结果。

关于np.take的小补充

你之前用np.take得到了不符合预期的形状,是因为没正确设置索引的维度和axis参数。如果一定要用take,可以这样写:

r = np.take(entries, indices, axis=1)[np.arange(4), np.arange(4)]

不过显然还是高级索引的写法更直观简洁,所以更推荐前者。

结果验证

你可以对比原来的列表推导式结果和新方法的结果,完全一致:

# 原来的列表推导式实现
r_old = np.array([entries[i, indices[i], :] for i in range(len(indices))])
print(np.array_equal(r, r_old))  # 输出:True

备注:内容来源于stack exchange,提问作者Thomas Koller

火山引擎 最新活动