如何在TensorFlow中复制已训练模型对象?解决copy.deepcopy失效问题
解决TensorFlow模型无法用
copy.deepcopy()复制的问题 你遇到的问题很常见——copy.deepcopy()没法直接复制TensorFlow模型对象,核心原因是:
- 模型里的
tf.Variable、tf.placeholder都是TensorFlow计算图的节点,它们的状态(比如训练好的权重)存储在TensorFlow会话(sess)中,而不是单纯的Python对象属性。 - 深拷贝只能复制Python层面的对象引用,没法同步计算图节点的关联关系,也没法把会话里的参数值转移到新对象中。
通用解决方案:手动实现模型克隆
我们可以给Model类添加一个clone方法,通过重新构建模型结构+复制参数值的方式实现完整的模型复制。具体步骤是:
- 创建一个全新的
Model实例(重新构建所有计算图节点) - 从原模型中读取训练好的参数值
- 在会话中将这些值赋值给新模型的对应变量
修改后的完整代码
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf FLAGS = None class Model(): def __init__(self): self.x = tf.placeholder(tf.float32, [None, 784]) self.W = tf.Variable(tf.zeros([784, 10])) self.b = tf.Variable(tf.zeros([10])) self.y = tf.matmul(self.x, self.W) + self.b self.y_ = tf.placeholder(tf.float32, [None, 10]) self.cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y_, logits=self.y)) self.train_step = tf.train.GradientDescentOptimizer(0.5).minimize(self.cross_entropy) # 收集所有需要复制的可训练变量(这里是W和b) self._trainable_vars = {var.name: var for var in [self.W, self.b]} def train(self, mnist, sess): for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(self.train_step, feed_dict={self.x: batch_xs, self.y_: batch_ys}) def test(self, mnist, sess): correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(f"模型准确率: {sess.run(accuracy, feed_dict={self.x: mnist.test.images, self.y_: mnist.test.labels})}") def clone(self, sess): """克隆当前模型,复制所有训练好的参数""" # 创建新的模型实例(重新构建计算图节点) cloned_model = Model() # 读取原模型的参数值 original_var_values = sess.run(self._trainable_vars) # 将参数值赋值给新模型的变量 for var_name, var in cloned_model._trainable_vars.items(): sess.run(var.assign(original_var_values[var_name])) return cloned_model def main(_): # 导入数据 mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) m = Model() sess = tf.InteractiveSession() tf.global_variables_initializer().run() # 训练原模型 m.train(mnist, sess) # 克隆模型(替换原来的copy.deepcopy) copy_of_m = m.clone(sess) # 测试两个模型的准确率(应该完全一致) print("原模型测试结果:") m.test(mnist, sess) print("克隆模型测试结果:") copy_of_m.test(mnist, sess) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', help='Directory for storing input data') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
为什么这个方法有效?
- 新模型会创建一套独立的计算图节点,避免和原模型的节点产生命名或依赖冲突
- 通过
var.assign()方法,我们直接将原模型训练好的权重值写入新模型的变量中,保证参数完全一致 - 克隆后的模型可以独立使用(比如继续训练、测试),不会影响原模型的状态
扩展方案:使用tf.train.Saver
如果你的模型结构复杂,也可以用TensorFlow的Saver类来实现模型的保存与加载:
- 用
tf.train.Saver()保存原模型的参数到磁盘 - 创建新的
Model实例,再用Saver加载参数到新模型中
不过对于这种简单模型,上面的clone方法更轻量直接。
内容的提问来源于stack exchange,提问作者ViniciusArruda




