You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

无法从sklearn.metrics导入plot_confusion_matrix的技术求助

解决from sklearn.metrics import plot_confusion_matrix的ImportError问题

你遇到这个导入错误的核心原因很明确:scikit-learn 1.2及更高版本已经移除了plot_confusion_matrix函数,官方用ConfusionMatrixDisplay类替代了它的功能。下面给你两种可行的解决办法:

方法一:迁移到新的API(推荐)

新的ConfusionMatrixDisplay提供了更灵活的可视化方式,用法和旧函数很接近,主要有两种调用方式:

直接通过模型生成混淆矩阵可视化

from sklearn.metrics import ConfusionMatrixDisplay

# 假设你已经训练好模型model,有测试集X_test和y_test
ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)

先计算混淆矩阵再可视化

如果需要先自定义计算混淆矩阵,再进行可视化:

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import matplotlib.pyplot as plt

# 计算混淆矩阵
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)

# 初始化并绘制
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()

方法二:降级scikit-learn到旧版本(不推荐)

如果你暂时不想修改代码,可以把scikit-learn降级到1.1.x版本(比如1.1.3,这是最后一个包含plot_confusion_matrix的稳定版本):

pip install scikit-learn==1.1.3

不过这个方法不建议长期使用,因为旧版本可能存在安全隐患,而且会错过scikit-learn后续的功能更新。

额外小提示

你可以通过以下代码确认自己当前的scikit-learn版本:

import sklearn
print(sklearn.__version__)

内容的提问来源于stack exchange,提问作者excelman

火山引擎 最新活动