如何用torch.gather实现3D张量按2D邻域索引提取[B,N,K,F]形状输出?
如何高效提取PyTorch张量的邻域元素
我来帮你搞定这个邻域提取的问题!你的思路其实已经找对方向了,但之前的实现有点冗余(比如没必要重复索引F次),而且维度处理可以更高效。咱们一步步来,先明确需求,再给出几种简洁靠谱的实现方式。
需求回顾
- 输入
data:形状[B, N, F],其中B是批次数量,N是点的总数,F是每个点的特征维度 - 输入
indices:形状[N, K],每个点对应K个邻域点的索引(索引针对N这个维度) - 目标输出:形状
[B, N, K, F],每个批次下的每个点,都包含它的K个邻域点的特征
示例数据先备好
先把你给的示例用PyTorch张量定义好,方便测试:
import torch data = torch.tensor([[[1,2,3],[2,3,4],[3,4,5]],[[3,4,5],[6,7,8],[2,3,4]]]) # shape [2,3,3] indices = torch.tensor([[1,2],[1,0],[0,0]]) # shape [3,2] B, N, F = data.shape K = indices.shape[1]
方法1:高级索引(最推荐,高效简洁)
利用PyTorch的广播索引特性,不需要额外重复数据,直接一步到位:
# 生成批次索引,形状[B,1,1],可以和indices的[1,N,K]广播匹配 batch_idx = torch.arange(B).unsqueeze(1).unsqueeze(2) # 直接索引,自动广播后得到[B,N,K,F]的结果 output = data[batch_idx, indices.unsqueeze(0), :]
测试一下结果:
print(output[0][0]) # 输出tensor([[2, 3, 4], [3, 4, 5]]),完全符合你的要求 print(output.shape) # 输出torch.Size([2, 3, 2, 3]),正确
这种方法没有多余的内存开销,代码也最简洁,是首选方案。
方法2:使用take_along_dim(语义更直观)
如果你用的是PyTorch 1.10及以上版本,take_along_dim比gather语义更清晰,用起来更顺手:
# 把indices扩展为[B,N,K],每个批次复用相同的索引 batched_indices = indices.unsqueeze(0).repeat(B, 1, 1) # 沿着dim=1的维度,用索引选取元素,最后扩展维度匹配特征维度 output = torch.take_along_dim(data.unsqueeze(2), batched_indices.unsqueeze(-1), dim=1)
这个方法的效果和高级索引完全一致,只是写法不同,适合喜欢语义明确API的同学。
方法3:修正你的gather实现
你的原始思路是对的,但可以优化掉不必要的重复操作:
# 扩展索引到[B,N,K,1],只需要重复批次维度,特征维度不需要重复(PyTorch会自动广播) batched_indices = indices.unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1, 1) # 把data扩展为[B,N,1,F],为邻域维度预留位置 data_expanded = data.unsqueeze(2) # 在dim=1维度上执行gather,得到目标形状 output = torch.gather(data_expanded, dim=1, index=batched_indices)
这里要注意:你之前把batched_indices重复了F次,其实完全没必要——gather在执行时会自动把索引的最后一个维度广播到特征维度,重复F次只会浪费内存,尤其是当F很大的时候。
为什么这些方法能工作?
核心是让索引的维度和数据的维度正确对齐:
- 批次维度:每个批次的邻域索引是相同的,所以只需要把
indices扩展出批次维度,或者用批次索引广播匹配 - 邻域维度:我们需要给每个点添加K个邻域的位置,所以要么扩展数据的维度,要么用索引的K维度去匹配
内容的提问来源于stack exchange,提问作者DsCpp




