You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何用torch.gather实现3D张量按2D邻域索引提取[B,N,K,F]形状输出?

如何高效提取PyTorch张量的邻域元素

我来帮你搞定这个邻域提取的问题!你的思路其实已经找对方向了,但之前的实现有点冗余(比如没必要重复索引F次),而且维度处理可以更高效。咱们一步步来,先明确需求,再给出几种简洁靠谱的实现方式。

需求回顾

  • 输入data:形状[B, N, F],其中B是批次数量,N是点的总数,F是每个点的特征维度
  • 输入indices:形状[N, K],每个点对应K个邻域点的索引(索引针对N这个维度)
  • 目标输出:形状[B, N, K, F],每个批次下的每个点,都包含它的K个邻域点的特征

示例数据先备好

先把你给的示例用PyTorch张量定义好,方便测试:

import torch

data = torch.tensor([[[1,2,3],[2,3,4],[3,4,5]],[[3,4,5],[6,7,8],[2,3,4]]])  # shape [2,3,3]
indices = torch.tensor([[1,2],[1,0],[0,0]])  # shape [3,2]
B, N, F = data.shape
K = indices.shape[1]

方法1:高级索引(最推荐,高效简洁)

利用PyTorch的广播索引特性,不需要额外重复数据,直接一步到位:

# 生成批次索引,形状[B,1,1],可以和indices的[1,N,K]广播匹配
batch_idx = torch.arange(B).unsqueeze(1).unsqueeze(2)
# 直接索引,自动广播后得到[B,N,K,F]的结果
output = data[batch_idx, indices.unsqueeze(0), :]

测试一下结果:

print(output[0][0])  # 输出tensor([[2, 3, 4], [3, 4, 5]]),完全符合你的要求
print(output.shape)  # 输出torch.Size([2, 3, 2, 3]),正确

这种方法没有多余的内存开销,代码也最简洁,是首选方案。

方法2:使用take_along_dim(语义更直观)

如果你用的是PyTorch 1.10及以上版本,take_along_dimgather语义更清晰,用起来更顺手:

# 把indices扩展为[B,N,K],每个批次复用相同的索引
batched_indices = indices.unsqueeze(0).repeat(B, 1, 1)
# 沿着dim=1的维度,用索引选取元素,最后扩展维度匹配特征维度
output = torch.take_along_dim(data.unsqueeze(2), batched_indices.unsqueeze(-1), dim=1)

这个方法的效果和高级索引完全一致,只是写法不同,适合喜欢语义明确API的同学。

方法3:修正你的gather实现

你的原始思路是对的,但可以优化掉不必要的重复操作:

# 扩展索引到[B,N,K,1],只需要重复批次维度,特征维度不需要重复(PyTorch会自动广播)
batched_indices = indices.unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1, 1)
# 把data扩展为[B,N,1,F],为邻域维度预留位置
data_expanded = data.unsqueeze(2)
# 在dim=1维度上执行gather,得到目标形状
output = torch.gather(data_expanded, dim=1, index=batched_indices)

这里要注意:你之前把batched_indices重复了F次,其实完全没必要——gather在执行时会自动把索引的最后一个维度广播到特征维度,重复F次只会浪费内存,尤其是当F很大的时候。

为什么这些方法能工作?

核心是让索引的维度和数据的维度正确对齐:

  • 批次维度:每个批次的邻域索引是相同的,所以只需要把indices扩展出批次维度,或者用批次索引广播匹配
  • 邻域维度:我们需要给每个点添加K个邻域的位置,所以要么扩展数据的维度,要么用索引的K维度去匹配

内容的提问来源于stack exchange,提问作者DsCpp

火山引擎 最新活动