You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch是否提供topk的反向操作以返回最小的k个元素?

PyTorch是否提供topk的反向操作以返回最小的k个元素?

嘿,这个问题问得很实用!其实PyTorch里并没有原生的torch.mink函数,不过你完全不需要自己绕弯子(比如取负再用topk),官方早就给torch.topk准备了更直接的解决方案。

  • 最推荐的方式:给torch.topk加上largest=False参数,直接获取最小的k个元素
    举个例子,用你给出的张量测试:

    import torch
    x = torch.arange(1., 6.)
    print(x)
    # >>> tensor([1., 2., 3., 4., 5.])
    
    # 获取最小的3个元素及其索引
    values, indices = torch.topk(x, 3, largest=False)
    print(values)
    # >>> tensor([1., 2., 3.])
    print(indices)
    # >>> tensor([0, 1, 2])
    

    这个方法不仅直观,还能避免取负操作可能带来的潜在问题(比如处理包含infNaN的张量时,取负可能会出现意外结果),而且和topk本身的性能一致,效率很高。

  • 另一种备选方案:先排序再取前k个
    如果你想先用torch.sort对整个张量排序,再截取前k个元素,也是可行的,但这种方法在张量规模较大时,性能不如topk(largest=False)(因为排序会处理整个张量,而topk只聚焦于需要的k个元素):

    sorted_vals, sorted_indices = torch.sort(x)
    min_k_vals = sorted_vals[:3]
    min_k_indices = sorted_indices[:3]
    

你之前想到的用torch.topk(-x, 3)的方法虽然能得到结果,但确实不是最优解,官方更推荐用largest=False的参数形式来实现“取最小k个元素”的需求。

备注:内容来源于stack exchange,提问作者Penguin

火山引擎 最新活动