You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于ImageAI训练YOLOv3时如何保存与加载模型状态?

解决ImageAI训练YOLOv3时的断点续训问题

嘿,我来帮你搞定Colab训练时长不够的问题!ImageAI其实已经内置了TensorFlow的模型断点续存功能,完全不用你手动写复杂的checkpoint逻辑,刚好适配你的需求。

一、修改代码开启自动保存Checkpoint

你的现有代码只需要添加几个参数,就能让训练过程自动保存模型状态。直接看修改后的代码:

from imageai.Detection.Custom import DetectionModelTrainer

trainer = DetectionModelTrainer()
trainer.setModelTypeAsYOLOv3()
trainer.setDataDirectory(data_directory="/content/drive/My Drive/Dataset")

# 配置训练参数,新增checkpoint相关设置
trainer.setTrainConfig(
    object_names_array=["obj1","obj2"],
    batch_size=4,
    num_experiments=421,  # 你的总训练epoch数保持不变
    save_model_every_number_of_epochs=10,  # 每10个epoch自动保存一次模型
    # 第一次训练时可以删掉下面这行,续训时再填入路径
    continue_from_model="/content/drive/My Drive/Dataset/models/detection_model-ex50.h5"
)

trainer.trainModel()

参数说明:

  • save_model_every_number_of_epochs:设置每隔多少个epoch保存一次模型状态。比如设为10,每跑完10个epoch,就会在你数据集目录下的models文件夹里生成类似detection_model-ex10.h5detection_model-ex20.h5的文件,这就是你的训练断点文件。
  • continue_from_model第一次训练时可以删除这个参数或者设为None,等需要续训时,把这个参数值改成你上次保存的最新checkpoint文件的完整路径(比如你上次跑到50个epoch,就填detection_model-ex50.h5的路径)。

二、续训的具体操作

  1. 当Colab会话断开或训练中途停止后,打开你存在Google Drive里的数据集目录,找到models文件夹,定位到最新的那个checkpoint文件(比如detection_model-ex50.h5)。
  2. 修改代码,把continue_from_model参数设置为这个文件的完整路径,同时保持num_experiments为你原本设定的总epoch数(比如还是421)。
  3. 重新运行代码,训练就会从50个epoch的进度继续往下跑,不用从头开始浪费时间。

三、关键注意事项

  • 务必把数据集和生成的checkpoint文件都存在Google Drive里!Colab的临时存储空间会在会话结束后清空,存在Drive里才能永久保留训练进度。
  • 如果你担心中途断开丢失太多进度,可以把save_model_every_number_of_epochs设小一点(比如5),不过太频繁会占用更多存储空间,按需调整即可。
  • 每个checkpoint文件都包含了当前训练的权重、优化器状态等完整信息,续训时能完全衔接之前的训练效果,不会出现断层。

内容的提问来源于stack exchange,提问作者Tech-D

火山引擎 最新活动