如何保存调用interpreter.resize_input_tensor后固定输入形状的TFLite模型?
如何保存调用interpreter.resize_input_tensor后固定输入形状的TFLite模型?
嗨,我之前也碰到过一模一样的需求,其实TensorFlow提供了实验性API可以直接帮你把调整好输入形状的模型保存下来,不用每次加载都重复执行resize操作,下面是具体的实现方案:
核心实现思路
在你调用resize_tensor_input固定好输入形状后,使用tf.lite.experimental.write_interpreter_to_file这个工具,就能把当前解释器里已经修改好配置的模型导出成新的固定形状TFLite文件。
修改后的完整代码示例
import tensorflow as tf import numpy as np # 假设你已经完成了输入图像的预处理,得到input_image # ... 你的图像加载、预处理代码 ... TFLITE_FILE_PATH = "/home/debian/sandbox/movenet/python/tflite/1.tflite" interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH) input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() print(input_details) print(output_details) is_dynamic_shape_model = input_details[0]['shape_signature'][2] == -1 if is_dynamic_shape_model: input_tensor_index = input_details[0]['index'] input_shape = input_image.shape interpreter.resize_tensor_input(input_tensor_index, input_shape, strict=True) # ------------------- 新增的模型保存代码 ------------------- SAVED_FIXED_MODEL_PATH = "/home/debian/sandbox/movenet_fixed_shape.tflite" tf.lite.experimental.write_interpreter_to_file(interpreter, SAVED_FIXED_MODEL_PATH) print(f"固定输入形状的模型已保存至: {SAVED_FIXED_MODEL_PATH}") # --------------------------------------------------------- interpreter.allocate_tensors() interpreter.set_tensor(input_details[0]['index'], input_image.numpy()) interpreter.invoke() keypoints_with_scores = interpreter.get_tensor(output_details[0]['index']) keypoints_with_scores = np.squeeze(keypoints_with_scores)
验证保存后的模型
下次加载这个新模型时,就不需要再做resize操作了,直接初始化解释器就能用:
# 加载固定形状的模型 fixed_interpreter = tf.lite.Interpreter(SAVED_FIXED_MODEL_PATH) fixed_interpreter.allocate_tensors() fixed_input_details = fixed_interpreter.get_input_details() print(f"固定模型的输入形状: {fixed_input_details[0]['shape']}") # 这里会输出你之前设置的input_image的具体形状,比如(1, 256, 256, 3) # 直接执行推理即可 fixed_interpreter.set_tensor(fixed_input_details[0]['index'], input_image.numpy()) fixed_interpreter.invoke() # ... 后续关键点处理逻辑和之前完全一致 ...
注意事项
- 确保你的TensorFlow版本是2.8及以上,这个实验性API在旧版本中可能没有;
strict=True参数会强制检查你设置的输入形状是否符合模型的约束,避免设置不兼容的形状导致报错;- 如果你需要适配多种不同的固定输入形状,需要分别保存对应形状的模型——单个TFLite模型只能绑定一种固定输入形状(除非保留动态维度)。
备注:内容来源于stack exchange,提问作者Shinobu HUYUGIRI




