You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何保存调用interpreter.resize_input_tensor后固定输入形状的TFLite模型?

如何保存调用interpreter.resize_input_tensor后固定输入形状的TFLite模型?

嗨,我之前也碰到过一模一样的需求,其实TensorFlow提供了实验性API可以直接帮你把调整好输入形状的模型保存下来,不用每次加载都重复执行resize操作,下面是具体的实现方案:

核心实现思路

在你调用resize_tensor_input固定好输入形状后,使用tf.lite.experimental.write_interpreter_to_file这个工具,就能把当前解释器里已经修改好配置的模型导出成新的固定形状TFLite文件。

修改后的完整代码示例

import tensorflow as tf
import numpy as np

# 假设你已经完成了输入图像的预处理,得到input_image
# ... 你的图像加载、预处理代码 ...

TFLITE_FILE_PATH = "/home/debian/sandbox/movenet/python/tflite/1.tflite"
interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH)

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

is_dynamic_shape_model = input_details[0]['shape_signature'][2] == -1
if is_dynamic_shape_model:
    input_tensor_index = input_details[0]['index']
    input_shape = input_image.shape
    interpreter.resize_tensor_input(input_tensor_index, input_shape, strict=True)

# ------------------- 新增的模型保存代码 -------------------
SAVED_FIXED_MODEL_PATH = "/home/debian/sandbox/movenet_fixed_shape.tflite"
tf.lite.experimental.write_interpreter_to_file(interpreter, SAVED_FIXED_MODEL_PATH)
print(f"固定输入形状的模型已保存至: {SAVED_FIXED_MODEL_PATH}")
# ---------------------------------------------------------

interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], input_image.numpy())

interpreter.invoke()

keypoints_with_scores = interpreter.get_tensor(output_details[0]['index'])
keypoints_with_scores = np.squeeze(keypoints_with_scores)

验证保存后的模型

下次加载这个新模型时,就不需要再做resize操作了,直接初始化解释器就能用:

# 加载固定形状的模型
fixed_interpreter = tf.lite.Interpreter(SAVED_FIXED_MODEL_PATH)
fixed_interpreter.allocate_tensors()

fixed_input_details = fixed_interpreter.get_input_details()
print(f"固定模型的输入形状: {fixed_input_details[0]['shape']}")
# 这里会输出你之前设置的input_image的具体形状,比如(1, 256, 256, 3)

# 直接执行推理即可
fixed_interpreter.set_tensor(fixed_input_details[0]['index'], input_image.numpy())
fixed_interpreter.invoke()
# ... 后续关键点处理逻辑和之前完全一致 ...

注意事项

  1. 确保你的TensorFlow版本是2.8及以上,这个实验性API在旧版本中可能没有;
  2. strict=True参数会强制检查你设置的输入形状是否符合模型的约束,避免设置不兼容的形状导致报错;
  3. 如果你需要适配多种不同的固定输入形状,需要分别保存对应形状的模型——单个TFLite模型只能绑定一种固定输入形状(除非保留动态维度)。

备注:内容来源于stack exchange,提问作者Shinobu HUYUGIRI

火山引擎 最新活动