You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何绘制多元线性回归的3D平面?现有实现方法存在问题

解决TensorFlow多元线性回归3D平面可视化问题

我看了你代码里的问题,主要是可视化逻辑走偏了,还有训练循环里的绘图操作完全没必要。咱们一步步来修正:

先说说你代码里的核心问题:

  • 你在500次训练循环里每次都调用plt.plot,这会在图上画500条杂乱的线,完全不是你要的3D平面
  • 尝试plot_surface时,你直接传入了真实标签y_data,但这个方法需要的是模型预测的平面值,而且输入必须是网格状数据,不是原始散点
  • 你的hypothesis写法可以简化,用TensorFlow的矩阵乘法更规范,也不容易出错

修正后的完整代码

import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from pandas.io.parsers import read_csv

# 初始化3D绘图对象
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 读取并预处理数据
data = read_csv('price data2.csv', sep=',')
xy = np.array(data, dtype=np.float32)
x_data = xy[0:500, 1:-1]  # 提取两个输入特征
y_data = xy[0:500, [-1]]  # 提取标签值

# TensorFlow模型定义
X = tf.placeholder(tf.float32, shape=[None, 2])
Y = tf.placeholder(tf.float32, shape=[None, 1])
W = tf.Variable(tf.random_normal([2, 1]), name="weight")
b = tf.Variable(tf.random_normal([1]), name="bias")
# 用矩阵乘法简化假设函数,更符合TensorFlow的运算逻辑
hypothesis = tf.matmul(X, W) + b

# 损失函数与优化器
cost = tf.reduce_mean(tf.square(hypothesis - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
train = optimizer.minimize(cost)

# 初始化会话并训练模型
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 训练过程:不要在循环里绘图!
for step in range(5000):  # 适当增加训练步数,提升拟合效果
    cost_, _ = sess.run([cost, train], feed_dict={X: x_data, Y: y_data})
    if step % 500 == 0:
        print(f"Step {step}, Cost: {cost_:.4f}")

# 训练完成后,准备可视化数据
# 生成覆盖特征取值范围的网格数据
x0_range = np.linspace(np.min(x_data[:, 0]), np.max(x_data[:, 0]), 100)
x1_range = np.linspace(np.min(x_data[:, 1]), np.max(x_data[:, 1]), 100)
x0, x1 = np.meshgrid(x0_range, x1_range)

# 计算网格点对应的模型预测值
X_grid = np.concatenate([x0.reshape(-1,1), x1.reshape(-1,1)], axis=1)
y_pred = sess.run(hypothesis, feed_dict={X: X_grid})
y_pred = y_pred.reshape(x0.shape)  # 转换成和网格匹配的形状

# 绘制原始数据散点
ax.scatter(x_data[:, 0], x_data[:, 1], y_data, c='r', marker='o', label='Raw Data')
# 绘制回归平面
ax.plot_surface(x0, x1, y_pred, alpha=0.5, cmap='viridis', label='Regression Plane')

# 设置坐标轴标签与图例
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_zlabel('Price')
ax.legend()

plt.show()
sess.close()

关键修改点说明:

  1. 移除循环内的绘图:训练过程中不需要实时绘图,等模型训练完成后再一次性绘制平面和散点
  2. 生成网格数据:用linspacemeshgrid生成覆盖特征范围的网格,这是plot_surface能画出连续平面的前提
  3. 计算预测平面值:把网格数据传入训练好的模型,得到每个网格点的预测值,再转换成网格形状用于绘图
  4. 优化模型定义:用tf.matmul(X, W) + b代替手动索引计算,更简洁也避免索引错误
  5. 增加训练步数:500步可能不足以拟合数据,改成5000步并定期打印损失值,方便观察训练进度

这样你就能看到清晰的3D回归平面和原始数据散点的对比效果了!

内容的提问来源于stack exchange,提问作者Tae

火山引擎 最新活动