You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在Keras中分别更新两个关联CNN模型并单独保存权重?

解决Keras中串联CNN模型的分步更新与权重保存问题

你遇到的核心问题是:直接用model.compile()fit_generator()(现在Keras的fit()已支持生成器)没法实现每次迭代都用model1当前更新后的输出来训练model2,常规fit流程会固定模型的输入输出关系,没法实时串联两个模型的动态更新。下面给你一套可行的解决方案,用自定义训练循环来精准实现需求:

核心思路

放弃Keras高层的fit()方法,改用TensorFlow低级API自定义训练流程,这样能完全控制每一步的权重更新逻辑:

  • 为两个模型分别定义独立的优化器和损失函数
  • 每次迭代中,先更新model1的权重,再用更新后的model1生成实时输出,作为model2的输入来更新model2
  • 全程手动控制权重保存,确保两个模型的权重独立存储

具体实现代码

1. 定义两个模型结构

确保model2的输入形状和model1的输出形状完全匹配:

import tensorflow as tf
from tensorflow.keras import layers, Model

# 构建model1:输入input1,输出对应你的output标签
def build_model1(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2,2))(x)
    x = layers.Conv2D(64, (3,3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2,2))(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(10, activation='softmax')(x)  # 替换成你实际的输出维度和激活函数
    return Model(inputs, outputs, name='model1')

# 构建model2:输入是model1的输出,输出对应你的output2标签
def build_model2(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = layers.Dense(64, activation='relu')(inputs)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(5, activation='sigmoid')(x)  # 替换成你实际的输出维度和激活函数
    return Model(inputs, outputs, name='model2')

# 初始化模型(替换成你实际的输入形状)
model1 = build_model1((28, 28, 1))
model2 = build_model2(model1.output_shape[1:])  # 自动匹配model1的输出形状

2. 定义独立的优化器和损失函数

两个模型用各自的优化器和损失,互不干扰:

# 为每个模型单独设置优化器(可根据需求调整学习率、优化器类型)
optimizer1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
optimizer2 = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9)

# 定义对应任务的损失函数(替换成你实际的损失类型)
loss_fn1 = tf.keras.losses.SparseCategoricalCrossentropy()  # 适合单分类任务
loss_fn2 = tf.keras.losses.BinaryCrossentropy()  # 适合多标签分类任务

3. 自定义训练循环

这是实现需求的关键部分,每一步都手动控制两个模型的更新:

# 替换成你的数据生成器,要求每个batch返回 (input1, label1, label2)
# label1是model1的目标输出,label2是model2的目标输出
def data_generator(batch_size):
    while True:
        # 这里只是示例,替换成你真实的数据加载逻辑
        input1 = tf.random.normal((batch_size, 28, 28, 1))
        label1 = tf.random.uniform((batch_size,), maxval=10, dtype=tf.int32)
        label2 = tf.random.uniform((batch_size, 5), maxval=2, dtype=tf.int32)
        yield input1, label1, label2

# 训练参数设置(替换成你的实际参数)
batch_size = 32
epochs = 500
steps_per_epoch = 100  # 每个epoch的batch数量

# 开始训练循环
for epoch in range(epochs):
    print(f"=== Epoch {epoch+1}/{epochs} ===")
    total_loss1 = 0.0
    total_loss2 = 0.0

    for step in range(steps_per_epoch):
        # 获取当前batch的数据
        input1, label1, label2 = next(data_generator(batch_size))

        # -------------------------- 训练model1 --------------------------
        with tf.GradientTape() as tape1:
            # 前向传播,training=True开启训练模式(激活Dropout、BN等层的训练行为)
            pred1 = model1(input1, training=True)
            # 计算model1的损失
            loss1 = loss_fn1(label1, pred1)
        
        # 计算梯度并更新model1的权重
        grads1 = tape1.gradient(loss1, model1.trainable_variables)
        optimizer1.apply_gradients(zip(grads1, model1.trainable_variables))

        # -------------------------- 训练model2 --------------------------
        with tf.GradientTape() as tape2:
            # 用更新后的model1生成实时输出,training=False关闭训练模式(避免再次更新model1)
            intermediate_output = model1(input1, training=False)
            # 前向传播训练model2
            pred2 = model2(intermediate_output, training=True)
            # 计算model2的损失
            loss2 = loss_fn2(label2, pred2)
        
        # 计算梯度并更新model2的权重
        grads2 = tape2.gradient(loss2, model2.trainable_variables)
        optimizer2.apply_gradients(zip(grads2, model2.trainable_variables))

        # 累加损失用于统计
        total_loss1 += loss1.numpy()
        total_loss2 += loss2.numpy()

    # 打印当前epoch的平均损失
    avg_loss1 = total_loss1 / steps_per_epoch
    avg_loss2 = total_loss2 / steps_per_epoch
    print(f"Model1 平均损失: {avg_loss1:.4f} | Model2 平均损失: {avg_loss2:.4f}")

    # 定期保存权重(比如每10个epoch保存一次)
    if (epoch + 1) % 10 == 0:
        model1.save_weights(f"model1_epoch_{epoch+1}.h5")
        model2.save_weights(f"model2_epoch_{epoch+1}.h5")
        print(f"已保存第{epoch+1}轮权重")

# 训练结束后保存最终权重
model1.save_weights("model1_final.h5")
model2.save_weights("model2_final.h5")
print("训练完成,已保存最终权重")

关键注意事项

  • training参数的设置:训练model2时,model1的training要设为False,这样既不会再次更新model1的权重,也能让model1中的BatchNormalization、Dropout等层使用训练阶段统计的均值/方差,保证输出稳定。
  • 数据生成器要求:你的数据生成器必须同时提供三个数据:input1(model1的输入)、label1(model1的目标输出)、label2(model2的目标输出),因为两个模型各自有独立的训练目标。
  • 权重加载:后续如果要加载权重,直接用model1.load_weights("xxx.h5")model2.load_weights("xxx.h5")即可,无需重新编译模型。

内容的提问来源于stack exchange,提问作者S.shin

火山引擎 最新活动