You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

TensorFlow.js(tfjs):已训练模型的保存及后续应用调用方法问询

保存与加载TensorFlow.js训练好的模型

嘿,恭喜你已经用TensorFlow.js训练出属于自己的模型啦!下面我就详细讲讲怎么把这个训练好的模型保存下来,以及后续在应用里怎么加载使用它——分浏览器和Node.js两种常见场景来说,你可以按需参考~


一、保存训练好的模型

1. 浏览器环境下的保存

浏览器里有两种常用的保存方式,看你的需求选:

  • 保存到本地文件系统:适合需要把模型下载到本地备份,或者部署到其他地方的场景。直接调用模型的save方法,指定downloads://前缀:

    // 假设你的模型变量叫model
    model.save('downloads://my-trained-model');
    

    执行后浏览器会自动下载两个类型的文件:一个model.json(包含模型的拓扑结构、权重索引等元信息),以及一组.bin格式的权重文件。

  • 保存到浏览器IndexedDB:适合需要在当前浏览器域名下持久化模型的场景(比如用户刷新页面后还能复用模型,不用重新训练)。用indexeddb://前缀:

    await model.save('indexeddb://my-trained-model');
    

    这个操作会把模型存在浏览器的IndexedDB数据库里,同一个域名下后续可以直接读取。

2. Node.js环境下的保存

如果是在Node.js环境中训练的模型(比如用@tensorflow/tfjs-node@tensorflow/tfjs-node-gpu),可以直接保存到本地文件系统:

const tf = require('@tensorflow/tfjs-node');

// 假设model是训练好的模型
await model.save('file://./models/my-trained-model');

执行后会在指定路径下生成model.json和对应的.bin权重文件,和浏览器下载的结构一致。


二、加载已保存的模型

1. 浏览器环境下的加载

  • 加载本地下载的模型:需要把model.json.bin文件部署到服务器(或者用本地开发服务器,比如VS Code的Live Server插件),然后用tf.loadLayersModel方法加载:

    const loadedModel = await tf.loadLayersModel('http://your-domain.com/models/my-trained-model/model.json');
    

    注意:如果直接打开本地HTML文件加载本地模型,会遇到跨域问题,所以一定要用服务器托管模型文件。

  • 加载IndexedDB里的模型:直接用之前保存的indexeddb://路径即可,同一个域名下无需额外部署:

    const loadedModel = await tf.loadLayersModel('indexeddb://my-trained-model');
    

2. Node.js环境下的加载

同样用tf.loadLayersModel,指定本地文件路径即可:

const tf = require('@tensorflow/tfjs-node');

const loadedModel = await tf.loadLayersModel('file://./models/my-trained-model/model.json');

小提示

如果你的模型使用了自定义层、自定义损失函数或者自定义指标,在加载模型之前一定要先注册这些自定义组件,否则会加载失败。比如自定义层的注册:

class MyCustomLayer extends tf.layers.Layer {
  // 自定义层的实现
}

// 注册自定义层
tf.serialization.registerClass(MyCustomLayer);

// 之后再加载模型
const loadedModel = await tf.loadLayersModel(...);

内容的提问来源于stack exchange,提问作者TheUnreal

火山引擎 最新活动