You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何用argsort对ranks按列排序,用其索引重排unsorted并取前2行(兼顾性能)

解决按列排序索引重排数组的问题

首先,你的代码有两个关键问题导致无法正常运行,同时针对大型数组的性能需求,我们可以用numpy的专用函数来高效解决:

问题分析

  1. 索引方式错误:你直接用unsorted[ind]会生成一个形状为(10,5,5)的数组,而不是期望的(10,5)。这是因为numpy会把ind中的每个元素当作行索引,对每行都取出整个unsorted数组的行,导致维度膨胀。
  2. 变量名冲突sorted是Python的内置函数,用它作为变量名会覆盖这个内置函数,可能引发后续问题,建议改用其他名称。

另外确认一下:你说的“按列对ranks排序”应该是指对每一列单独排序(即沿着行的方向,把每列的元素从小到大排列),如果是这个需求,那argsortaxis参数应该设为0(你之前用了axis=1,这是对每行的列进行排序)。

正确高效的代码实现

我们可以用numpy的np.take_along_axis函数,它专门用于沿着指定轴应用排序后的索引,性能经过优化,非常适合处理大型数组:

import numpy as np

# 生成测试数据
ranks = np.random.uniform(0, 1, (10, 5))
unsorted = np.random.uniform(0, 1, (10, 5))

# 按列对ranks排序,得到每列的排序索引(axis=0表示沿列方向排序)
ind = np.argsort(ranks, axis=0)

# 用索引重排unsorted数组,沿列方向应用索引
sorted_unsorted = np.take_along_axis(unsorted, ind, axis=0)

# 选取前2行
result = sorted_unsorted[:2, :]

性能说明

np.take_along_axis是numpy底层优化的函数,用C实现,避免了手动索引可能产生的中间数组复制,对于50000×5000这样的大型数组,它的执行效率远高于手动构造广播索引的方式,能有效节省内存和计算时间。

如果你的需求是对每行的列按ranks的行排序(即每行内的列重新排列),只需要把axis参数都改成1即可:

ind = np.argsort(ranks, axis=1)
sorted_unsorted = np.take_along_axis(unsorted, ind, axis=1)
result = sorted_unsorted[:2, :]

内容的提问来源于stack exchange,提问作者lara_toff

火山引擎 最新活动