You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中添加[:]才可复制ResNet18预训练权重的原因及作用

为什么在PyTorch中替换模型权重时需要使用[:]?

这是个非常典型的PyTorch张量操作细节问题,我来给你拆解清楚:

先回顾你的场景:你有一个预训练好的ResNet18,还有一个自定义初始化的reg_resnet模型,想要把预训练权重迁移到自定义模型上。带[:]的代码能成功替换权重,直接赋值却不行,核心原因在于PyTorch中state_dict()的本质,以及两种赋值方式的差异。

1. 先搞懂state_dict()到底是什么

PyTorch模型的state_dict()返回的是一个有序字典,里面存的是模型所有可学习参数(权重、偏置等)的张量引用。但要注意:当你直接通过reg_resnet.state_dict()[each_param]取到这个张量时,它其实是一个「不可直接替换引用的视图」——你没法通过直接赋值把这个字典里的引用换成另一个张量,来改变模型的实际参数。

2. 直接用=赋值为什么无效?

当你写reg_resnet.state_dict()[each_param] = resnet18.state_dict()[each_param]时,你只是把state_dict这个字典里对应key的引用值给换掉了,但模型实际在训练、推理时用的权重,是存在它自身的参数列表(比如reg_resnet.parameters()返回的那些张量)里的。state_dict只是一个实时的参数快照字典,你替换字典里的引用,根本不会同步到模型的实际参数张量上,所以这种操作完全是做无用功。

3. [:]切片赋值到底起了什么作用?

reg_resnet.state_dict()[each_param][:] = resnet18.state_dict()[each_param]这种写法,是对张量的内容进行原地修改

  • [:]相当于选中了张量的所有元素,它会绕过“替换引用”的操作,直接访问张量的底层内存;
  • 这个操作会把resnet18对应张量里的数值,原地拷贝reg_resnet的参数张量内存中;
  • 因为state_dict里的引用和模型实际参数指向的是同一块内存,所以模型的实际权重也就被修改了,这才完成了真正的权重替换。

额外推荐:更规范的权重迁移方式

其实PyTorch官方提供了更可靠的方法,不需要手动循环赋值:

resnet18 = models.resnet18(pretrained=True)
reg_resnet = resnet_model()
# 用load_state_dict自动完成权重赋值,strict=False允许模型结构有少量差异
reg_resnet.load_state_dict(resnet18.state_dict(), strict=False)

这个方法会自动处理张量的赋值逻辑,比手动循环更简洁也更不容易出错。

内容的提问来源于stack exchange,提问作者cmed123

火山引擎 最新活动