You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch报错:'torch.FloatTensor'对象无item属性求助

解决PyTorch中AttributeError: 'torch.FloatTensor' object has no attribute 'item'的问题

这个坑我之前也踩过,来给你理清楚原因和解决办法:

错误原因

你遇到的报错本质上是两个可能的原因:

  • PyTorch版本过低item()方法是PyTorch 0.4.0版本才新增的特性,专门用来从标量张量里提取Python原生数值。如果你的版本比0.4.0老,自然找不到这个方法。
  • 张量形状不匹配:即使版本够,你代码里的b_target = torch.randn(1) *5生成的是形状为(1,)的1维张量,而早期版本的item()只支持0维的标量张量,对1维张量调用就会报错。

解决方案

方案1:升级PyTorch(推荐)

如果你的项目没有版本限制,直接升级到0.4.0及以上版本就能解决问题。新版本不仅支持标量张量用item(),连形状为(1,)的1维张量也能直接调用这个方法。升级命令根据你的包管理器选择:

# pip用户
pip install --upgrade torch

# conda用户
conda update pytorch

方案2:不升级版本,修改代码

如果不能升级PyTorch,把b_target.item()换成下面任意一种写法都能正常运行:

  • 直接索引取值:利用张量的索引特性获取第一个元素
    def f(x):
        return x.mm(W_target) + b_target[0]
    
  • 转成numpy数组后取值:先把张量转成numpy数组,再取第一个元素
    def f(x):
        return x.mm(W_target) + b_target.numpy()[0]
    
  • 直接转成Python数值:把张量强制转成float类型
    def f(x):
        return x.mm(W_target) + float(b_target)
    

随便选一种替换后,你的代码就能正常运行啦~

内容的提问来源于stack exchange,提问作者Bobo Xi

火山引擎 最新活动