You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch转Keras后NCHW格式TensorFlow图转NHWC适配TFLite方法咨询

解决NCHW格式TensorFlow计算图转NHWC适配TFLite的方案

我之前处理过PyTorch转Keras后因NCHW格式导致TFLite转换失败的问题,给你几个实用的解决思路:

1. 在Keras模型层面插入维度转换层(最直接)

既然你的Keras模型是NCHW格式,最省心的办法是给模型包裹一层维度转换的Lambda层,把输入输出都转成NHWC格式,这样整个计算图就适配TFLite的要求了。

举个代码例子:

import tensorflow as tf
from tensorflow.keras import Model, Input, layers

# 假设你已经得到了NCHW格式的Keras模型:model_nchw
# 定义NCHW格式的输入(比如输入是3通道224x224的图像)
nchw_input = Input(shape=(3, 224, 224))

# 把输入从NCHW转成NHWC:perm=[0,2,3,1]对应批量、高度、宽度、通道
nhwc_input = layers.Lambda(lambda x: tf.transpose(x, perm=[0, 2, 3, 1]))(nchw_input)

# 让原模型处理转换后的输入
model_output = model_nchw(nhwc_input)

# 如果原模型的输出也是NCHW格式,同样转成NHWC(如果你的业务不需要可以跳过这步)
nhwc_output = layers.Lambda(lambda x: tf.transpose(x, perm=[0, 2, 3, 1]))(model_output)

# 构建新的NHWC格式模型
model_nhwc = Model(inputs=nchw_input, outputs=nhwc_output)

# 验证一下输出维度是否正确
model_nhwc.summary()

这个方法的好处是不需要修改原模型的结构,只做外层封装,而且转换后可以直接测试模型输出和原模型是否一致,避免出错。

2. 修改TensorFlow计算图节点(适合已导出的计算图)

如果已经导出了TensorFlow计算图(比如.pb文件),可以通过遍历图中的节点,手动插入transpose操作来转换格式。不过这个方法相对繁琐,需要对TensorFlow计算图结构有一定了解:

  • 找到输入节点,在其后添加tf.transpose节点将NCHW转NHWC;
  • 找到输出节点,在其前添加tf.transpose节点将NCHW转NHWC;
  • 重新保存修改后的计算图。

举个简化的代码示例:

import tensorflow as tf

# 加载已有的NCHW格式计算图
with tf.io.gfile.GFile('your_nchw_graph.pb', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# 添加输入转NHWC的节点
input_node_name = 'your_input_node'  # 替换成你的输入节点名
transpose_input = tf.constant([0, 2, 3, 1], name='transpose_perm')
tf.import_graph_def(graph_def, name='')
with tf.compat.v1.Session() as sess:
    input_tensor = sess.graph.get_tensor_by_name(f'{input_node_name}:0')
    nhwc_input = tf.transpose(input_tensor, perm=transpose_input, name='nhwc_input')
    # 这里需要把原模型的后续节点连接到nhwc_input,然后处理输出...
    # 最后保存修改后的图
    output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['your_output_node']  # 替换成你的输出节点名
    )
    with tf.io.gfile.GFile('your_nhwc_graph.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())

这个方法适合已经导出计算图不想重新转换Keras模型的场景,但需要精准定位输入输出节点,调试成本稍高。

3. 转换TFLite时的特殊处理(部分场景可用)

有些情况下,你可以在TFLite转换时尝试添加参数强制处理格式,但这个方法并不通用,比如:

converter = tf.lite.TFLiteConverter.from_keras_model(model_nchw)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
# 尝试设置输入格式
converter.inference_input_type = tf.float32
# 转换
tflite_model = converter.convert()

但这种方法大概率还是会失败,因为TFLite对NCHW格式的支持有限,所以还是推荐前两种方法。

最后提醒一下,转换完成后一定要对比原模型和新模型的输出结果,确保维度转换没有导致计算错误哦!

内容的提问来源于stack exchange,提问作者Abhishek Sehgal

火山引擎 最新活动