PyTorch/Numpy批量矩阵运算:如何用torch.bmm实现指定计算?
解决PyTorch中用bmm实现矩阵与批量向量的乘法问题
我完全懂你的困惑——torch.bmm确实要求两个输入都是3D批量张量,且批量维度得放在第一位置,但你的第一个输入是2D矩阵,没法直接用。不过我们可以通过调整张量形状来适配,同时也给你推荐更简便的替代方案,顺便覆盖Numpy用户的情况。
PyTorch 实现方法
方法1:调整形状适配torch.bmm
按照bmm的规则,我们需要把两个输入都转换成3D张量,且保证批量维度一致:
- 对
matrix(M×N):添加一个批量维度,再重复B次,得到B×M×N的张量(这样每个批量样本都是同一个M×N矩阵) - 对
batch(N×B):先转置成B×N,再添加最后一维变成B×N×1(这样每个批量样本是N×1的向量) - 执行
bmm后得到B×M×1的结果,最后挤压掉最后一维并转置,就能得到目标的M×B张量
代码示例:
import torch M, N, B = 5, 3, 4 matrix = torch.randn(M, N) batch = torch.randn(N, B) # 调整形状适配bmm matrix_batched = matrix.unsqueeze(0).repeat(B, 1, 1) # shape: (B, M, N) batch_batched = batch.T.unsqueeze(-1) # shape: (B, N, 1) # 执行批量矩阵乘法 result_bmm = torch.bmm(matrix_batched, batch_batched) # shape: (B, M, 1) final_result = result_bmm.squeeze(-1).T # shape: (M, B)
方法2:更简便的torch.matmul(推荐)
其实你根本没必要用bmm!torch.matmul支持广播机制,直接对2D的matrix和batch做乘法就能得到想要的结果,代码简洁又高效:
final_result = torch.matmul(matrix, batch) # shape: (M, B) # 或者直接用@运算符,和matmul等价 final_result = matrix @ batch
这个方法既满足需求,又避免了不必要的张量重复,性能更好。
Numpy 对应实现
Numpy用户遇到的问题类似,同样可以用两种方式解决:
方法1:调整形状适配批量乘法
import numpy as np M, N, B = 5, 3, 4 matrix = np.random.randn(M, N) batch = np.random.randn(N, B) # 调整形状 matrix_batched = np.tile(matrix[np.newaxis, :, :], (B, 1, 1)) # shape: (B, M, N) batch_batched = batch.T[:, :, np.newaxis] # shape: (B, N, 1) # 批量乘法 result_bmm = np.matmul(matrix_batched, batch_batched) # shape: (B, M, 1) final_result = result_bmm.squeeze(-1).T # shape: (M, B)
方法2:直接用矩阵乘法(推荐)
Numpy的@运算符或者np.matmul同样支持这种维度的乘法,一步到位:
final_result = matrix @ batch # shape: (M, B) # 或者 final_result = np.matmul(matrix, batch)
总结一下,如果你只是需要完成matrix @ batch的运算,直接用matmul或者@运算符是最优解;如果一定要用bmm(比如有特殊批量场景需求),就通过调整形状来适配即可。
内容的提问来源于stack exchange,提问作者Nick




