You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

将Detectron2的.h5模型转为.tflite时触发AttributeError错误,求解决方案

将Detectron2的.h5模型转为.tflite时触发AttributeError错误,求解决方案

兄弟,你这问题核心是混淆了不同深度学习框架的模型格式,用错了转换工具,咱们一步步来解决:

先捋清楚你代码里的问题

你现在的代码同时混用了PyTorch/Detectron2和TensorFlow/Keras的逻辑,还犯了一个关键错误:
Detectron2的.h5文件只是PyTorch模型的权重文件,不是Keras框架的完整模型文件,但你却用了TensorFlow的TFLiteConverter.from_keras_model去加载它——这个方法只接受Keras模型对象或者TensorFlow SavedModel的路径,自然会触发AttributeError。

正确的转换流程:Detectron2(PyTorch)→ ONNX → TFLite

因为Detectron2是基于PyTorch的,没法直接用TensorFlow的工具转TFLite,得先把PyTorch模型转成ONNX格式,再转成TFLite,具体步骤如下:

步骤1:加载Detectron2模型并导出为ONNX

首先确保你有模型对应的配置文件(比如训练时用的config.yaml),然后执行以下代码:

import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

# 配置模型参数
cfg = get_cfg()
cfg.merge_from_file("path/to/your/training_config.yaml")  # 替换成你自己的配置文件路径
cfg.MODEL.WEIGHTS = "/kaggle/input/vehicleobjectdetection/pytorch/v1/2/model_final.h5"
cfg.MODEL.DEVICE = "cpu"  # 用CPU导出避免CUDA兼容性问题
cfg.freeze()

# 加载模型并切换到评估模式
model = build_model(cfg)
model.eval()
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

# 创建示例输入(尺寸要和你训练时的输入一致,比如COCO默认是(3, 800, 1024))
example_input = torch.randn(1, 3, 800, 1024)  # batch_size=1, 通道数3,高800,宽1024

# 导出ONNX模型
torch.onnx.export(
    model,
    example_input,
    "detectron2_model.onnx",
    opset_version=11,  # 选兼容性较好的opset版本
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}  # 可选:支持动态batch大小
)

步骤2:把ONNX模型转成TFLite

接下来用ONNX转TensorFlow的工具,再导出TFLite:

import tensorflow as tf
import onnx
from onnx_tf.backend import prepare

# 加载ONNX模型
onnx_model = onnx.load("detectron2_model.onnx")
tf_rep = prepare(onnx_model)

# 先导出为TensorFlow SavedModel格式
tf_rep.export_graph("detectron2_saved_model")

# 转换为TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("detectron2_saved_model")
tflite_model = converter.convert()

# 保存最终的TFLite模型
with open("detectron2_model.tflite", "wb") as f:
    f.write(tflite_model)

额外注意事项

  • 确保安装了所有依赖:pip install onnx onnx-tf tensorflow torch detectron2
  • 示例输入的尺寸必须和你训练模型时的输入尺寸匹配,否则转换后的模型可能无法正常工作
  • 如果你的模型包含自定义层,可能需要额外处理ONNX导出的兼容性问题,比如注册自定义层的ONNX转换器

备注:内容来源于stack exchange,提问作者Muhammad Ammar

火山引擎 最新活动