PyTorch转Keras后NCHW格式TensorFlow图转NHWC适配TFLite方法咨询
解决NCHW格式TensorFlow计算图转NHWC适配TFLite的方案
我之前处理过PyTorch转Keras后因NCHW格式导致TFLite转换失败的问题,给你几个实用的解决思路:
1. 在Keras模型层面插入维度转换层(最直接)
既然你的Keras模型是NCHW格式,最省心的办法是给模型包裹一层维度转换的Lambda层,把输入输出都转成NHWC格式,这样整个计算图就适配TFLite的要求了。
举个代码例子:
import tensorflow as tf from tensorflow.keras import Model, Input, layers # 假设你已经得到了NCHW格式的Keras模型:model_nchw # 定义NCHW格式的输入(比如输入是3通道224x224的图像) nchw_input = Input(shape=(3, 224, 224)) # 把输入从NCHW转成NHWC:perm=[0,2,3,1]对应批量、高度、宽度、通道 nhwc_input = layers.Lambda(lambda x: tf.transpose(x, perm=[0, 2, 3, 1]))(nchw_input) # 让原模型处理转换后的输入 model_output = model_nchw(nhwc_input) # 如果原模型的输出也是NCHW格式,同样转成NHWC(如果你的业务不需要可以跳过这步) nhwc_output = layers.Lambda(lambda x: tf.transpose(x, perm=[0, 2, 3, 1]))(model_output) # 构建新的NHWC格式模型 model_nhwc = Model(inputs=nchw_input, outputs=nhwc_output) # 验证一下输出维度是否正确 model_nhwc.summary()
这个方法的好处是不需要修改原模型的结构,只做外层封装,而且转换后可以直接测试模型输出和原模型是否一致,避免出错。
2. 修改TensorFlow计算图节点(适合已导出的计算图)
如果已经导出了TensorFlow计算图(比如.pb文件),可以通过遍历图中的节点,手动插入transpose操作来转换格式。不过这个方法相对繁琐,需要对TensorFlow计算图结构有一定了解:
- 找到输入节点,在其后添加
tf.transpose节点将NCHW转NHWC; - 找到输出节点,在其前添加
tf.transpose节点将NCHW转NHWC; - 重新保存修改后的计算图。
举个简化的代码示例:
import tensorflow as tf # 加载已有的NCHW格式计算图 with tf.io.gfile.GFile('your_nchw_graph.pb', 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) # 添加输入转NHWC的节点 input_node_name = 'your_input_node' # 替换成你的输入节点名 transpose_input = tf.constant([0, 2, 3, 1], name='transpose_perm') tf.import_graph_def(graph_def, name='') with tf.compat.v1.Session() as sess: input_tensor = sess.graph.get_tensor_by_name(f'{input_node_name}:0') nhwc_input = tf.transpose(input_tensor, perm=transpose_input, name='nhwc_input') # 这里需要把原模型的后续节点连接到nhwc_input,然后处理输出... # 最后保存修改后的图 output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( sess, sess.graph_def, ['your_output_node'] # 替换成你的输出节点名 ) with tf.io.gfile.GFile('your_nhwc_graph.pb', 'wb') as f: f.write(output_graph_def.SerializeToString())
这个方法适合已经导出计算图不想重新转换Keras模型的场景,但需要精准定位输入输出节点,调试成本稍高。
3. 转换TFLite时的特殊处理(部分场景可用)
有些情况下,你可以在TFLite转换时尝试添加参数强制处理格式,但这个方法并不通用,比如:
converter = tf.lite.TFLiteConverter.from_keras_model(model_nchw) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] # 尝试设置输入格式 converter.inference_input_type = tf.float32 # 转换 tflite_model = converter.convert()
但这种方法大概率还是会失败,因为TFLite对NCHW格式的支持有限,所以还是推荐前两种方法。
最后提醒一下,转换完成后一定要对比原模型和新模型的输出结果,确保维度转换没有导致计算错误哦!
内容的提问来源于stack exchange,提问作者Abhishek Sehgal




