You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

不创建Sequential模型,如何向预训练ResNet50添加层以保留原模型各层级可访问性

解决ResNet50微调时保留内部层级访问的问题

你遇到的问题确实很典型——用Sequential包裹预训练模型时,它会把整个ResNet50当作一个单一的"层",导致内部结构被封装无法访问。要解决这个问题,改用Keras的函数式API就可以完美保留原始模型的所有层级结构,同时灵活添加自定义层。

具体实现步骤

函数式API是Keras中构建复杂模型的首选方式,尤其适合这种拼接预训练模型和自定义层的场景:

  1. 加载预训练ResNet50(不带顶层)
    首先加载ResNet50时,记得设置include_top=False,这样会去掉原始的分类顶层,只保留特征提取部分:

    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras import layers, Model
    
    # 加载预训练模型,去掉顶层分类层
    resnet50 = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    
  2. 用函数式API拼接自定义层
    直接基于ResNet50的输出,逐步添加你的全连接层:

    # 获取ResNet50的输出特征
    x = resnet50.output
    # 添加全局平均池化(因为ResNet50输出是特征图,需要转换成一维向量)
    x = layers.GlobalAveragePooling2D()(x)
    # 添加自定义隐藏层
    x = layers.Dense(1024, activation='relu', name='hidden_layer')(x)
    # 添加二分类多标签的输出层(sigmoid激活)
    output = layers.Dense(2, activation='sigmoid', name='output')(x)
    
    # 构建完整模型
    self.model = Model(inputs=resnet50.input, outputs=output)
    
  3. 验证层级可访问性
    现在运行self.model.summary(),你会看到ResNet50的所有内部层(比如conv1_convconv2_block1_1_conv等)和你添加的自定义层都清晰列出。你可以通过以下方式访问任意内部层:

    # 获取ResNet50模型实例
    resnet_inner = self.model.get_layer('resnet50')
    # 获取ResNet内部的某个具体层
    conv5_layer = resnet_inner.get_layer('conv5_block3_out')
    

为什么这样做更适合你的需求?

  • 保留层级结构:函数式API不会封装预训练模型的内部结构,所有层都处于同一个模型图中,方便后续提取显著性图时访问中间层或计算输入的梯度。
  • 灵活微调:你可以轻松控制ResNet50的可训练状态,比如先冻结所有ResNet层训练自定义顶层,再解冻部分深层进行微调:
    # 先冻结ResNet50的所有层
    resnet50.trainable = False
    # 编译模型并训练顶层
    self.model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    self.model.fit(...)
    
    # 微调阶段:解冻ResNet50的部分层
    resnet50.trainable = True
    # 比如只解冻最后3个卷积块
    for layer in resnet50.layers[:-10]:
        layer.trainable = False
    # 重新编译(用更小的学习率)
    self.model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='binary_crossentropy', metrics=['accuracy'])
    self.model.fit(...)
    

对提取Saliency Maps的帮助

因为现在你可以直接访问模型的输入层和任意中间层,计算显著性图时(比如通过输入对输出的梯度)会非常方便。例如,你可以用TensorFlow的梯度带计算输入图像对应某个输出类别的梯度,进而生成显著性图——而这一切都依赖于模型保留了完整的层级结构,没有被封装成黑盒。

内容的提问来源于stack exchange,提问作者eidnas

火山引擎 最新活动