You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch/Numpy批量矩阵运算:如何用torch.bmm实现指定计算?

解决PyTorch中用bmm实现矩阵与批量向量的乘法问题

我完全懂你的困惑——torch.bmm确实要求两个输入都是3D批量张量,且批量维度得放在第一位置,但你的第一个输入是2D矩阵,没法直接用。不过我们可以通过调整张量形状来适配,同时也给你推荐更简便的替代方案,顺便覆盖Numpy用户的情况。

PyTorch 实现方法

方法1:调整形状适配torch.bmm

按照bmm的规则,我们需要把两个输入都转换成3D张量,且保证批量维度一致:

  1. matrix(M×N):添加一个批量维度,再重复B次,得到B×M×N的张量(这样每个批量样本都是同一个M×N矩阵)
  2. batch(N×B):先转置成B×N,再添加最后一维变成B×N×1(这样每个批量样本是N×1的向量)
  3. 执行bmm后得到B×M×1的结果,最后挤压掉最后一维并转置,就能得到目标的M×B张量

代码示例:

import torch

M, N, B = 5, 3, 4
matrix = torch.randn(M, N)
batch = torch.randn(N, B)

# 调整形状适配bmm
matrix_batched = matrix.unsqueeze(0).repeat(B, 1, 1)  # shape: (B, M, N)
batch_batched = batch.T.unsqueeze(-1)  # shape: (B, N, 1)

# 执行批量矩阵乘法
result_bmm = torch.bmm(matrix_batched, batch_batched)  # shape: (B, M, 1)
final_result = result_bmm.squeeze(-1).T  # shape: (M, B)

方法2:更简便的torch.matmul(推荐)

其实你根本没必要用bmmtorch.matmul支持广播机制,直接对2D的matrixbatch做乘法就能得到想要的结果,代码简洁又高效:

final_result = torch.matmul(matrix, batch)  # shape: (M, B)
# 或者直接用@运算符,和matmul等价
final_result = matrix @ batch

这个方法既满足需求,又避免了不必要的张量重复,性能更好。

Numpy 对应实现

Numpy用户遇到的问题类似,同样可以用两种方式解决:

方法1:调整形状适配批量乘法

import numpy as np

M, N, B = 5, 3, 4
matrix = np.random.randn(M, N)
batch = np.random.randn(N, B)

# 调整形状
matrix_batched = np.tile(matrix[np.newaxis, :, :], (B, 1, 1))  # shape: (B, M, N)
batch_batched = batch.T[:, :, np.newaxis]  # shape: (B, N, 1)

# 批量乘法
result_bmm = np.matmul(matrix_batched, batch_batched)  # shape: (B, M, 1)
final_result = result_bmm.squeeze(-1).T  # shape: (M, B)

方法2:直接用矩阵乘法(推荐)

Numpy的@运算符或者np.matmul同样支持这种维度的乘法,一步到位:

final_result = matrix @ batch  # shape: (M, B)
# 或者
final_result = np.matmul(matrix, batch)

总结一下,如果你只是需要完成matrix @ batch的运算,直接用matmul或者@运算符是最优解;如果一定要用bmm(比如有特殊批量场景需求),就通过调整形状来适配即可。

内容的提问来源于stack exchange,提问作者Nick

火山引擎 最新活动