You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在PyTorch中实现批量masked_select操作?

当然有不少更简洁或者更高效的替代方案!原方案用循环处理每行虽然直观,但在处理大规模张量时,我们可以利用PyTorch的向量化操作来优化性能,同时让代码更紧凑。下面分享几种可行的实现方式:

方案1:全向量化操作(无显式循环,适合大规模数据)

这个方法通过构造索引矩阵,直接将选中的元素映射到目标位置,完全避免了循环,性能最优:

import torch

x = torch.tensor([[1., 2., 2., 2., 3.], [1., 2., 4., 3., 2.]])
masks = torch.tensor([[True, False, False, False, True], [True, False, True, True, False]])

# 计算每行需要保留的元素数量
keep_counts = masks.sum(dim=1)  # 得到 tensor([2, 3])
batch_size, seq_len = x.shape

# 创建行索引矩阵,用于定位每行元素
row_indices = torch.arange(batch_size).unsqueeze(1).repeat(1, seq_len)
# 创建列索引矩阵,用于标记目标位置
col_indices = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)

# 生成掩码:只保留每行前N个位置(N为该行要保留的元素数)用于填充选中元素
valid_pos_mask = col_indices < keep_counts.unsqueeze(1)

# 获取所有mask为True的元素的列索引,并按行分组
selected_cols = torch.where(masks)[1].reshape(batch_size, -1)

# 初始化结果为全1张量,然后将选中元素填入对应位置
result = torch.ones_like(x)
result[valid_pos_mask] = x[row_indices[masks], selected_cols.flatten()]

print(result)
# 输出:
# tensor([[1., 3., 1., 1., 1.],
#         [1., 4., 3., 1., 1.]])

方案2:简化循环逻辑(更紧凑的循环实现)

如果觉得全向量化的索引处理有点复杂,也可以用split方法直接按行拆分选中的元素,让循环代码更简洁:

import torch

x = torch.tensor([[1., 2., 2., 2., 3.], [1., 2., 4., 3., 2.]])
masks = torch.tensor([[True, False, False, False, True], [True, False, True, True, False]])

result = torch.ones_like(x)
# 将所有选中的元素按行拆分,得到每行的选中元素列表
selected_elements = torch.masked_select(x, masks).split(masks.sum(dim=1).tolist())

# 遍历每行,将选中元素放到行首
for idx, elems in enumerate(selected_elements):
    result[idx, :len(elems)] = elems

print(result)

方案3:固定长度场景下的极简实现

如果你的场景中,每行需要保留的元素数量是固定的(比如所有行都保留2个元素),可以用pad函数快速完成填充:

import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 2., 2., 3.], [1., 2., 4., 3., 2.]])
masks = torch.tensor([[True, False, False, False, True], [True, False, True, True, False]])

# 提取所有选中元素并按行重组
selected = torch.masked_select(x, masks).reshape(x.shape[0], -1)
# 计算需要填充的1的数量,用pad补全到原张量长度
pad_length = x.shape[1] - selected.shape[1]
result = F.pad(selected, (0, pad_length), mode='constant', value=1.)

print(result)

⚠️ 注意:这个方法仅适用于所有行保留元素数量相同的情况,如果每行数量不一致,会报错。

内容的提问来源于stack exchange,提问作者Daniel

火山引擎 最新活动