如何在Keras中分别更新两个关联CNN模型并单独保存权重?
解决Keras中串联CNN模型的分步更新与权重保存问题
你遇到的核心问题是:直接用model.compile()和fit_generator()(现在Keras的fit()已支持生成器)没法实现每次迭代都用model1当前更新后的输出来训练model2,常规fit流程会固定模型的输入输出关系,没法实时串联两个模型的动态更新。下面给你一套可行的解决方案,用自定义训练循环来精准实现需求:
核心思路
放弃Keras高层的fit()方法,改用TensorFlow低级API自定义训练流程,这样能完全控制每一步的权重更新逻辑:
- 为两个模型分别定义独立的优化器和损失函数
- 每次迭代中,先更新model1的权重,再用更新后的model1生成实时输出,作为model2的输入来更新model2
- 全程手动控制权重保存,确保两个模型的权重独立存储
具体实现代码
1. 定义两个模型结构
确保model2的输入形状和model1的输出形状完全匹配:
import tensorflow as tf from tensorflow.keras import layers, Model # 构建model1:输入input1,输出对应你的output标签 def build_model1(input_shape): inputs = layers.Input(shape=input_shape) x = layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputs) x = layers.MaxPooling2D((2,2))(x) x = layers.Conv2D(64, (3,3), activation='relu', padding='same')(x) x = layers.MaxPooling2D((2,2))(x) x = layers.Flatten()(x) outputs = layers.Dense(10, activation='softmax')(x) # 替换成你实际的输出维度和激活函数 return Model(inputs, outputs, name='model1') # 构建model2:输入是model1的输出,输出对应你的output2标签 def build_model2(input_shape): inputs = layers.Input(shape=input_shape) x = layers.Dense(64, activation='relu')(inputs) x = layers.Dropout(0.2)(x) outputs = layers.Dense(5, activation='sigmoid')(x) # 替换成你实际的输出维度和激活函数 return Model(inputs, outputs, name='model2') # 初始化模型(替换成你实际的输入形状) model1 = build_model1((28, 28, 1)) model2 = build_model2(model1.output_shape[1:]) # 自动匹配model1的输出形状
2. 定义独立的优化器和损失函数
两个模型用各自的优化器和损失,互不干扰:
# 为每个模型单独设置优化器(可根据需求调整学习率、优化器类型) optimizer1 = tf.keras.optimizers.Adam(learning_rate=1e-3) optimizer2 = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9) # 定义对应任务的损失函数(替换成你实际的损失类型) loss_fn1 = tf.keras.losses.SparseCategoricalCrossentropy() # 适合单分类任务 loss_fn2 = tf.keras.losses.BinaryCrossentropy() # 适合多标签分类任务
3. 自定义训练循环
这是实现需求的关键部分,每一步都手动控制两个模型的更新:
# 替换成你的数据生成器,要求每个batch返回 (input1, label1, label2) # label1是model1的目标输出,label2是model2的目标输出 def data_generator(batch_size): while True: # 这里只是示例,替换成你真实的数据加载逻辑 input1 = tf.random.normal((batch_size, 28, 28, 1)) label1 = tf.random.uniform((batch_size,), maxval=10, dtype=tf.int32) label2 = tf.random.uniform((batch_size, 5), maxval=2, dtype=tf.int32) yield input1, label1, label2 # 训练参数设置(替换成你的实际参数) batch_size = 32 epochs = 500 steps_per_epoch = 100 # 每个epoch的batch数量 # 开始训练循环 for epoch in range(epochs): print(f"=== Epoch {epoch+1}/{epochs} ===") total_loss1 = 0.0 total_loss2 = 0.0 for step in range(steps_per_epoch): # 获取当前batch的数据 input1, label1, label2 = next(data_generator(batch_size)) # -------------------------- 训练model1 -------------------------- with tf.GradientTape() as tape1: # 前向传播,training=True开启训练模式(激活Dropout、BN等层的训练行为) pred1 = model1(input1, training=True) # 计算model1的损失 loss1 = loss_fn1(label1, pred1) # 计算梯度并更新model1的权重 grads1 = tape1.gradient(loss1, model1.trainable_variables) optimizer1.apply_gradients(zip(grads1, model1.trainable_variables)) # -------------------------- 训练model2 -------------------------- with tf.GradientTape() as tape2: # 用更新后的model1生成实时输出,training=False关闭训练模式(避免再次更新model1) intermediate_output = model1(input1, training=False) # 前向传播训练model2 pred2 = model2(intermediate_output, training=True) # 计算model2的损失 loss2 = loss_fn2(label2, pred2) # 计算梯度并更新model2的权重 grads2 = tape2.gradient(loss2, model2.trainable_variables) optimizer2.apply_gradients(zip(grads2, model2.trainable_variables)) # 累加损失用于统计 total_loss1 += loss1.numpy() total_loss2 += loss2.numpy() # 打印当前epoch的平均损失 avg_loss1 = total_loss1 / steps_per_epoch avg_loss2 = total_loss2 / steps_per_epoch print(f"Model1 平均损失: {avg_loss1:.4f} | Model2 平均损失: {avg_loss2:.4f}") # 定期保存权重(比如每10个epoch保存一次) if (epoch + 1) % 10 == 0: model1.save_weights(f"model1_epoch_{epoch+1}.h5") model2.save_weights(f"model2_epoch_{epoch+1}.h5") print(f"已保存第{epoch+1}轮权重") # 训练结束后保存最终权重 model1.save_weights("model1_final.h5") model2.save_weights("model2_final.h5") print("训练完成,已保存最终权重")
关键注意事项
- training参数的设置:训练model2时,model1的
training要设为False,这样既不会再次更新model1的权重,也能让model1中的BatchNormalization、Dropout等层使用训练阶段统计的均值/方差,保证输出稳定。 - 数据生成器要求:你的数据生成器必须同时提供三个数据:
input1(model1的输入)、label1(model1的目标输出)、label2(model2的目标输出),因为两个模型各自有独立的训练目标。 - 权重加载:后续如果要加载权重,直接用
model1.load_weights("xxx.h5")和model2.load_weights("xxx.h5")即可,无需重新编译模型。
内容的提问来源于stack exchange,提问作者S.shin




