PyTorch报错:'torch.FloatTensor'对象无item属性求助
解决PyTorch中
AttributeError: 'torch.FloatTensor' object has no attribute 'item'的问题 这个坑我之前也踩过,来给你理清楚原因和解决办法:
错误原因
你遇到的报错本质上是两个可能的原因:
- PyTorch版本过低:
item()方法是PyTorch 0.4.0版本才新增的特性,专门用来从标量张量里提取Python原生数值。如果你的版本比0.4.0老,自然找不到这个方法。 - 张量形状不匹配:即使版本够,你代码里的
b_target = torch.randn(1) *5生成的是形状为(1,)的1维张量,而早期版本的item()只支持0维的标量张量,对1维张量调用就会报错。
解决方案
方案1:升级PyTorch(推荐)
如果你的项目没有版本限制,直接升级到0.4.0及以上版本就能解决问题。新版本不仅支持标量张量用item(),连形状为(1,)的1维张量也能直接调用这个方法。升级命令根据你的包管理器选择:
# pip用户 pip install --upgrade torch # conda用户 conda update pytorch
方案2:不升级版本,修改代码
如果不能升级PyTorch,把b_target.item()换成下面任意一种写法都能正常运行:
- 直接索引取值:利用张量的索引特性获取第一个元素
def f(x): return x.mm(W_target) + b_target[0] - 转成numpy数组后取值:先把张量转成numpy数组,再取第一个元素
def f(x): return x.mm(W_target) + b_target.numpy()[0] - 直接转成Python数值:把张量强制转成float类型
def f(x): return x.mm(W_target) + float(b_target)
随便选一种替换后,你的代码就能正常运行啦~
内容的提问来源于stack exchange,提问作者Bobo Xi




