如何用argsort对ranks按列排序,用其索引重排unsorted并取前2行(兼顾性能)
解决按列排序索引重排数组的问题
首先,你的代码有两个关键问题导致无法正常运行,同时针对大型数组的性能需求,我们可以用numpy的专用函数来高效解决:
问题分析
- 索引方式错误:你直接用
unsorted[ind]会生成一个形状为(10,5,5)的数组,而不是期望的(10,5)。这是因为numpy会把ind中的每个元素当作行索引,对每行都取出整个unsorted数组的行,导致维度膨胀。 - 变量名冲突:
sorted是Python的内置函数,用它作为变量名会覆盖这个内置函数,可能引发后续问题,建议改用其他名称。
另外确认一下:你说的“按列对ranks排序”应该是指对每一列单独排序(即沿着行的方向,把每列的元素从小到大排列),如果是这个需求,那argsort的axis参数应该设为0(你之前用了axis=1,这是对每行的列进行排序)。
正确高效的代码实现
我们可以用numpy的np.take_along_axis函数,它专门用于沿着指定轴应用排序后的索引,性能经过优化,非常适合处理大型数组:
import numpy as np # 生成测试数据 ranks = np.random.uniform(0, 1, (10, 5)) unsorted = np.random.uniform(0, 1, (10, 5)) # 按列对ranks排序,得到每列的排序索引(axis=0表示沿列方向排序) ind = np.argsort(ranks, axis=0) # 用索引重排unsorted数组,沿列方向应用索引 sorted_unsorted = np.take_along_axis(unsorted, ind, axis=0) # 选取前2行 result = sorted_unsorted[:2, :]
性能说明
np.take_along_axis是numpy底层优化的函数,用C实现,避免了手动索引可能产生的中间数组复制,对于50000×5000这样的大型数组,它的执行效率远高于手动构造广播索引的方式,能有效节省内存和计算时间。
如果你的需求是对每行的列按ranks的行排序(即每行内的列重新排列),只需要把axis参数都改成1即可:
ind = np.argsort(ranks, axis=1) sorted_unsorted = np.take_along_axis(unsorted, ind, axis=1) result = sorted_unsorted[:2, :]
内容的提问来源于stack exchange,提问作者lara_toff




