基于TensorFlow.js修改官网基础示例,实现电影类型输入的喜好预测
用TensorFlow.js实现电影类型偏好预测的改造方案
我帮你把TensorFlow.js的基础线性回归示例,改成能输入电影类型数组、预测你是否喜欢的分类模型,一步步来:
1. 调整模型结构适配多特征输入
原来的模型是单输入的线性回归,现在我们要处理电影类型数组这种多特征输入,而且是二分类任务(喜欢/不喜欢),所以得改模型结构:
// 定义适配多特征的分类模型 const model = tf.sequential(); // 输入维度对应你的电影类型数量,比如你选3种类型就填[3] model.add(tf.layers.dense({units: 8, activation: 'relu', inputShape: [3]})); // 输出层用sigmoid激活,输出0-1之间的概率值(代表喜欢的概率) model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
2. 准备贴合需求的训练数据
你需要把电影类型转换成数值数组(比如用one-hot编码:某类存在填1,否则填0),再配上你真实的偏好标签(1=喜欢,0=不喜欢),示例如下:
// 训练特征:每一行对应一组电影类型组合 const xs = tf.tensor2d([ [1, 0, 0], // 纯动作片 [0, 1, 0], // 纯喜剧片 [0, 0, 1], // 纯科幻片 [1, 1, 0], // 动作+喜剧 [1, 0, 1], // 动作+科幻 [0, 1, 1], // 喜剧+科幻 [1, 1, 1] // 全类型混合 ]); // 标签:对应上面每组类型的偏好,0-1之间的数值都可以(越接近1越喜欢) const ys = tf.tensor2d([ [1], [0], [1], [0.8], [0.9], [0.3], [0.7] ]);
3. 编译并训练模型
因为是二分类任务,要把损失函数换成binaryCrossentropy,优化器用adam会比原示例的sgd更适配,还可以加准确率指标方便看训练效果:
// 编译模型,适配二分类任务 model.compile({ loss: 'binaryCrossentropy', optimizer: 'adam', metrics: ['accuracy'] // 实时查看训练准确率 }); // 异步训练函数,避免阻塞主线程 async function trainModel() { const history = await model.fit(xs, ys, { epochs: 500, // 训练轮次,可根据效果调整 batchSize: 2, callbacks: { onEpochEnd: (epoch, logs) => { console.log(`第${epoch+1}轮训练:损失值=${logs.loss.toFixed(4)},准确率=${logs.accuracy.toFixed(4)}`); } } }); console.log('训练完成!'); } // 启动训练 trainModel();
4. 输入电影类型数组做预测
写个预测函数,输入你指定的电影类型数组,就能得到喜欢的概率和结果:
// 预测函数:输入电影类型数组,返回偏好结果 async function predictPreference(movieTypes) { // 将输入转为TensorFlow.js能处理的张量 const input = tf.tensor2d([movieTypes]); // 预测喜欢的概率 const prediction = await model.predict(input).data(); // 判断是否喜欢:概率>0.5就认为喜欢 const isLiked = prediction[0] > 0.5; console.log(`输入的电影类型数组:${movieTypes}`); console.log(`喜欢的概率:${(prediction[0]*100).toFixed(2)}%`); console.log(`预测结果:${isLiked ? '喜欢' : '不喜欢'}`); // 手动清理张量,避免内存泄漏 input.dispose(); } // 示例:预测「动作+科幻」类型的偏好 predictPreference([1, 0, 1]);
小提示
- 你可以根据自己的需求扩展电影类型数量,只要同步修改
inputShape的数值就行 - 训练数据越多、标签越贴合你的真实偏好,模型预测就越准
- 如果遇到内存问题,记得用
tf.tidy()包裹计算逻辑,或者手动调用dispose()清理不用的张量
内容的提问来源于stack exchange,提问作者mohas




