PyTorch是否提供topk的反向操作以返回最小的k个元素?
PyTorch是否提供topk的反向操作以返回最小的k个元素?
嘿,这个问题问得很实用!其实PyTorch里并没有原生的torch.mink函数,不过你完全不需要自己绕弯子(比如取负再用topk),官方早就给torch.topk准备了更直接的解决方案。
最推荐的方式:给
torch.topk加上largest=False参数,直接获取最小的k个元素
举个例子,用你给出的张量测试:import torch x = torch.arange(1., 6.) print(x) # >>> tensor([1., 2., 3., 4., 5.]) # 获取最小的3个元素及其索引 values, indices = torch.topk(x, 3, largest=False) print(values) # >>> tensor([1., 2., 3.]) print(indices) # >>> tensor([0, 1, 2])这个方法不仅直观,还能避免取负操作可能带来的潜在问题(比如处理包含
inf或NaN的张量时,取负可能会出现意外结果),而且和topk本身的性能一致,效率很高。另一种备选方案:先排序再取前k个
如果你想先用torch.sort对整个张量排序,再截取前k个元素,也是可行的,但这种方法在张量规模较大时,性能不如topk(largest=False)(因为排序会处理整个张量,而topk只聚焦于需要的k个元素):sorted_vals, sorted_indices = torch.sort(x) min_k_vals = sorted_vals[:3] min_k_indices = sorted_indices[:3]
你之前想到的用torch.topk(-x, 3)的方法虽然能得到结果,但确实不是最优解,官方更推荐用largest=False的参数形式来实现“取最小k个元素”的需求。
备注:内容来源于stack exchange,提问作者Penguin




