You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

关于Scikit-learn中Precision-Recall Curve计算的困惑

理解Scikit-learn中的Precision-Recall Curve计算

咱们直接结合Scikit-learn官方给出的例子,一步步拆解precision_recall_curve函数的计算逻辑,这样你就能明白输出结果是怎么来的了。

首先先看官方的代码片段:

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1])  # 真实标签,2个负类,2个正类
y_scores = np.array([0.1, 0.4, 0.35, 0.8])  # 模型输出的得分/概率

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

print("precision:", precision)  # 输出: [0.66666667 0.5        1.         1.        ]
print("recall:", recall)        # 输出: [1.  0.5 0.5 0. ]
print("thresholds:", thresholds)# 输出: [0.35 0.4  0.8 ]

关键概念回顾

  • 精确率(Precision)TP / (TP + FP),即预测为正类的样本中真正是正类的比例
  • 召回率(Recall)TP / (TP + FN),即真实正类中被正确预测为正类的比例
  • 阈值(Threshold):判断样本是否为正类的得分临界值,得分≥阈值则预测为正类

逐个分析阈值对应的结果

函数会从y_scores中提取唯一值并排序,作为候选阈值,最后额外补充一个终点让曲线完整:

  1. 阈值=0.35

    • 所有得分≥0.35的样本:索引1(0.4)、2(0.35)、3(0.8)
    • 真实标签对应:0、1、1
    • TP=2(真实正类且预测正类),FP=1(真实负类但预测正类)
    • Precision = 2/(2+1) ≈ 0.666...,Recall = 2/2 = 1.0
  2. 阈值=0.4

    • 所有得分≥0.4的样本:索引1(0.4)、3(0.8)
    • 真实标签对应:0、1
    • TP=1,FP=1
    • Precision = 1/(1+1) = 0.5,Recall = 1/2 = 0.5
  3. 阈值=0.8

    • 所有得分≥0.8的样本:索引3(0.8)
    • 真实标签对应:1
    • TP=1,FP=0
    • Precision = 1/(1+0) = 1.0,Recall = 1/2 = 0.5

最后一组(precision=1.0, recall=0.0)的由来

这是函数自动添加的曲线终点:当阈值设置为比所有样本得分都高时,没有样本会被预测为正类。此时TP=0,FP=0,数学上精确率为0/0无意义,Scikit-learn默认将其设为1.0;而召回率=0/2=0.0,这样曲线就能从(recall=1, precision≈0.66)完整延伸到(recall=0, precision=1.0)。

内容的提问来源于stack exchange,提问作者zeal

火山引擎 最新活动