You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于TensorFlow.js修改官网基础示例,实现电影类型输入的喜好预测

用TensorFlow.js实现电影类型偏好预测的改造方案

我帮你把TensorFlow.js的基础线性回归示例,改成能输入电影类型数组、预测你是否喜欢的分类模型,一步步来:

1. 调整模型结构适配多特征输入

原来的模型是单输入的线性回归,现在我们要处理电影类型数组这种多特征输入,而且是二分类任务(喜欢/不喜欢),所以得改模型结构:

// 定义适配多特征的分类模型
const model = tf.sequential();
// 输入维度对应你的电影类型数量,比如你选3种类型就填[3]
model.add(tf.layers.dense({units: 8, activation: 'relu', inputShape: [3]}));
// 输出层用sigmoid激活,输出0-1之间的概率值(代表喜欢的概率)
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));

2. 准备贴合需求的训练数据

你需要把电影类型转换成数值数组(比如用one-hot编码:某类存在填1,否则填0),再配上你真实的偏好标签(1=喜欢,0=不喜欢),示例如下:

// 训练特征:每一行对应一组电影类型组合
const xs = tf.tensor2d([
  [1, 0, 0], // 纯动作片
  [0, 1, 0], // 纯喜剧片
  [0, 0, 1], // 纯科幻片
  [1, 1, 0], // 动作+喜剧
  [1, 0, 1], // 动作+科幻
  [0, 1, 1], // 喜剧+科幻
  [1, 1, 1]  // 全类型混合
]);

// 标签:对应上面每组类型的偏好,0-1之间的数值都可以(越接近1越喜欢)
const ys = tf.tensor2d([
  [1], [0], [1], [0.8], [0.9], [0.3], [0.7]
]);

3. 编译并训练模型

因为是二分类任务,要把损失函数换成binaryCrossentropy,优化器用adam会比原示例的sgd更适配,还可以加准确率指标方便看训练效果:

// 编译模型,适配二分类任务
model.compile({
  loss: 'binaryCrossentropy',
  optimizer: 'adam',
  metrics: ['accuracy'] // 实时查看训练准确率
});

// 异步训练函数,避免阻塞主线程
async function trainModel() {
  const history = await model.fit(xs, ys, {
    epochs: 500, // 训练轮次,可根据效果调整
    batchSize: 2,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`第${epoch+1}轮训练:损失值=${logs.loss.toFixed(4)},准确率=${logs.accuracy.toFixed(4)}`);
      }
    }
  });
  console.log('训练完成!');
}

// 启动训练
trainModel();

4. 输入电影类型数组做预测

写个预测函数,输入你指定的电影类型数组,就能得到喜欢的概率和结果:

// 预测函数:输入电影类型数组,返回偏好结果
async function predictPreference(movieTypes) {
  // 将输入转为TensorFlow.js能处理的张量
  const input = tf.tensor2d([movieTypes]);
  // 预测喜欢的概率
  const prediction = await model.predict(input).data();
  // 判断是否喜欢:概率>0.5就认为喜欢
  const isLiked = prediction[0] > 0.5;
  
  console.log(`输入的电影类型数组:${movieTypes}`);
  console.log(`喜欢的概率:${(prediction[0]*100).toFixed(2)}%`);
  console.log(`预测结果:${isLiked ? '喜欢' : '不喜欢'}`);
  
  // 手动清理张量,避免内存泄漏
  input.dispose();
}

// 示例:预测「动作+科幻」类型的偏好
predictPreference([1, 0, 1]);

小提示

  • 你可以根据自己的需求扩展电影类型数量,只要同步修改inputShape的数值就行
  • 训练数据越多、标签越贴合你的真实偏好,模型预测就越准
  • 如果遇到内存问题,记得用tf.tidy()包裹计算逻辑,或者手动调用dispose()清理不用的张量

内容的提问来源于stack exchange,提问作者mohas

火山引擎 最新活动