You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

TensorFlow模型转TensorFlow Lite后输入输出形状不匹配求助

解决TensorFlow Lite转换后输入输出形状不匹配的问题

我看你遇到的核心问题是模型路径不匹配,外加可能没有显式指定输入形状导致的,下面一步步帮你解决:

1. 修正最明显的路径错误

你下载并解压的是ssd_resnet50_v1_fpn_640x640_coco17_tpu-8模型,但转换时却指向了ssd_mobilenet_v2_320x320_coco17_tpu-8/saved_model——这两个模型的输入尺寸(640x640 vs 320x320)完全不同,自然会导致输入输出形状不匹配。

2. 显式指定输入形状(推荐)

目标检测模型常有动态维度,TFLite转换器有时无法自动推断正确的输入形状,建议手动指定:

  • 先查看原模型的输入节点名称和形状:
import tensorflow as tf
loaded_model = tf.saved_model.load('ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/saved_model')
in_tensor = list(loaded_model.signatures['serving_default'].inputs)[0]
print(f"输入节点名称: {in_tensor.name}, 形状: {in_tensor.shape}")
  • 初始化转换器时带上输入形状参数

3. 修正后的完整转换代码

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xzvf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!pip install tensorflow>=2.4  # 推荐用稳定版,而非tf-nightly

import tensorflow as tf

# 确认输入节点信息
loaded_model = tf.saved_model.load('ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/saved_model')
in_tensor = list(loaded_model.signatures['serving_default'].inputs)[0]
input_node_name = in_tensor.name.split(':')[0]  # 提取节点名(去掉":0"后缀)

# 初始化转换器,指定正确路径和输入形状
converter = tf.lite.TFLiteConverter.from_saved_model(
    'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/saved_model',
    input_shapes={input_node_name: [1, 640, 640, 3]}  # 对应原模型的输入尺寸
)

# 配置转换参数
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

# 执行转换并保存
tflite_model = converter.convert()
with open("ssd_resnet50_fpn_640x640.tflite", "wb") as f:
    f.write(tflite_model)

4. 验证转换后的模型形状

转换完成后,可以用TFLite解释器检查输入输出是否和原模型一致:

interpreter = tf.lite.Interpreter(model_path="ssd_resnet50_fpn_640x640.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("输入形状:", input_details[0]['shape'])
print("输出形状:", [out['shape'] for out in output_details])

内容的提问来源于stack exchange,提问作者H.H

火山引擎 最新活动