使用Supervision库为YOLOv8/YOLOv11计算混淆矩阵时的数据集配置与空矩阵问题排查
Supervision库为YOLOv8/YOLOv11计算混淆矩阵时的数据集配置与空矩阵问题排查
你遇到的问题核心是Supervision版本迭代后的接口变更,加上数据集路径/格式的细节没踩对,导致数据集加载异常、混淆矩阵为空。我来一步步给你梳理正确的解决流程:
一、先解决版本与接口不匹配的问题
你用的是supervision==0.26.1,这个版本的DetectionDataset接口已经和旧教程/笔记不一致了——原来的dataset.images属性被移除,换成了迭代器方式获取样本。比如要查看数据集样本,应该用next(iter(dataset))而非dataset.images.keys()。
更推荐升级到稳定兼容版本:配对supervision>=0.28.0和ultralytics>=8.0.200,能避免大部分兼容性问题;如果不想升级,要严格对应旧版本的调用逻辑。
二、正确配置YOLO格式数据集加载
你的from_yolo方法参数方向是对的,但要注意几个关键细节(这是最容易导致数据集加载失败的点):
- 路径必须精准对应验证集
images_directory_path:必须是验证集图片的独立根目录(比如./dataset/val/images),不能用整个数据集的图片根目录,否则会混入训练集样本annotations_directory_path:必须是验证集标注的独立根目录(比如./dataset/val/labels),每个图片必须对应同名的YOLO格式.txt标注(即使无目标也要有空的.txt,否则样本会被跳过)data_yaml_path:数据集YAML文件,里面要正确包含names(类别列表)和nc(类别数),YAML里的数据集路径可忽略,因为我们已手动指定了验证集路径
- 验证数据集是否真的加载成功
加载后用以下方式调试,替代旧的dataset.images:# 查看类别是否正确匹配 print("Dataset classes:", dataset.classes) # 查看验证集样本总数(如果为0,说明路径完全错误) print("Total validation samples:", len(dataset)) # 取出第一个样本验证标注是否加载 sample_image, sample_annotations = next(iter(dataset)) print("Sample image shape:", sample_image.shape) print("Sample annotation count:", len(sample_annotations.class_id))
三、完整可运行的混淆矩阵计算代码
结合你的场景,给你适配supervision>=0.26.0的完整代码:
import supervision as sv from ultralytics import YOLO import numpy as np # 替换为你的实际路径 IMAGES_VAL_DIR = "./dataset/val/images" # 验证集图片根目录 ANNOTS_VAL_DIR = "./dataset/val/labels" # 验证集标注根目录 YAML_PATH = "./dataset/data.yaml" # 数据集类别YAML文件 MODEL_PATH = "./runs/detect/train/weights/best.pt" # 训练好的模型权重 # 1. 加载验证集 dataset = sv.DetectionDataset.from_yolo( images_directory_path=IMAGES_VAL_DIR, annotations_directory_path=ANNOTS_VAL_DIR, data_yaml_path=YAML_PATH ) # 调试:确认数据集加载正常 print(f"Loaded {len(dataset)} validation samples") print(f"Class list: {dataset.classes}") # 2. 加载YOLO模型 model = YOLO(MODEL_PATH) # 3. 定义预测回调函数(适配Supervision要求) def callback(image: np.ndarray) -> sv.Detections: # 保持与训练一致的预处理参数 results = model.predict( source=image, imgsz=640, # 替换为你训练时用的尺寸 conf=0.25, # 适当降低阈值避免漏检 verbose=False # 关闭预测日志,避免刷屏 ) return sv.Detections.from_ultralytics(results[0]) # 4. 计算混淆矩阵 confusion_matrix = sv.ConfusionMatrix.benchmark( dataset=dataset, callback=callback, class_names=dataset.classes # 显式传入类别,确保映射正确 ) # 5. 可视化并保存结果 confusion_matrix.plot( class_names=dataset.classes, title="YOLOv8 Validation Confusion Matrix", save_path="./confusion_matrix.png" )
四、空混淆矩阵的常见排查点
如果还是出现空矩阵,按以下顺序排查:
- 文件名一致性检查:图片与标注文件名必须完全一致(包括大小写,Linux/macOS区分大小写),比如
person_001.jpg对应person_001.txt - 标注格式正确性:YOLO标注每一行是
class_id x_center y_center width height,所有坐标必须是归一化后的值(0-1之间),绝对像素值会导致加载失败 - 模型预测有效性:在
callback函数中加调试代码,确认模型有输出:
如果输出全为0,说明模型在验证集上无检测结果,可能是权重错误、置信度过高,或验证集与训练集分布差异过大def callback(image: np.ndarray) -> sv.Detections: results = model.predict(source=image, conf=0.25, verbose=False) print(f"Detected {len(results[0].boxes)} objects in sample") return sv.Detections.from_ultralytics(results[0]) - 类别ID顺序一致性:YAML文件的
names顺序必须与模型训练时的YAML完全一致,否则类别映射会出错,导致混淆矩阵显示异常
五、旧版本Supervision(0.26.1)适配方案
如果不想升级Supervision,需调整数据集遍历和混淆矩阵生成逻辑:
# 旧版本查看样本的方式 for sample in dataset: image, annotations = sample print(image.shape, annotations.class_id) break # 旧版本手动计算混淆矩阵 confusion_matrix = sv.ConfusionMatrix( num_classes=len(dataset.classes), class_names=dataset.classes ) # 手动遍历数据集更新矩阵 for image, annotations in dataset: predictions = callback(image) confusion_matrix.update( annotations.class_id, predictions.class_id ) confusion_matrix.plot()
按以上步骤一步步调试,先确保数据集能正确加载出样本和标注,再跑混淆矩阵,就能解决空矩阵的问题了。




