关于Scikit-learn中Precision-Recall Curve计算的困惑
理解Scikit-learn中的Precision-Recall Curve计算
咱们直接结合Scikit-learn官方给出的例子,一步步拆解precision_recall_curve函数的计算逻辑,这样你就能明白输出结果是怎么来的了。
首先先看官方的代码片段:
import numpy as np from sklearn.metrics import precision_recall_curve y_true = np.array([0, 0, 1, 1]) # 真实标签,2个负类,2个正类 y_scores = np.array([0.1, 0.4, 0.35, 0.8]) # 模型输出的得分/概率 precision, recall, thresholds = precision_recall_curve(y_true, y_scores) print("precision:", precision) # 输出: [0.66666667 0.5 1. 1. ] print("recall:", recall) # 输出: [1. 0.5 0.5 0. ] print("thresholds:", thresholds)# 输出: [0.35 0.4 0.8 ]
关键概念回顾
- 精确率(Precision):
TP / (TP + FP),即预测为正类的样本中真正是正类的比例 - 召回率(Recall):
TP / (TP + FN),即真实正类中被正确预测为正类的比例 - 阈值(Threshold):判断样本是否为正类的得分临界值,得分≥阈值则预测为正类
逐个分析阈值对应的结果
函数会从y_scores中提取唯一值并排序,作为候选阈值,最后额外补充一个终点让曲线完整:
阈值=0.35
- 所有得分≥0.35的样本:索引1(0.4)、2(0.35)、3(0.8)
- 真实标签对应:0、1、1
- TP=2(真实正类且预测正类),FP=1(真实负类但预测正类)
- Precision = 2/(2+1) ≈ 0.666...,Recall = 2/2 = 1.0
阈值=0.4
- 所有得分≥0.4的样本:索引1(0.4)、3(0.8)
- 真实标签对应:0、1
- TP=1,FP=1
- Precision = 1/(1+1) = 0.5,Recall = 1/2 = 0.5
阈值=0.8
- 所有得分≥0.8的样本:索引3(0.8)
- 真实标签对应:1
- TP=1,FP=0
- Precision = 1/(1+0) = 1.0,Recall = 1/2 = 0.5
最后一组(precision=1.0, recall=0.0)的由来
这是函数自动添加的曲线终点:当阈值设置为比所有样本得分都高时,没有样本会被预测为正类。此时TP=0,FP=0,数学上精确率为0/0无意义,Scikit-learn默认将其设为1.0;而召回率=0/2=0.0,这样曲线就能从(recall=1, precision≈0.66)完整延伸到(recall=0, precision=1.0)。
内容的提问来源于stack exchange,提问作者zeal




