You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

使用Supervision库为YOLOv8/YOLOv11计算混淆矩阵时的数据集配置与空矩阵问题排查

Supervision库为YOLOv8/YOLOv11计算混淆矩阵时的数据集配置与空矩阵问题排查

你遇到的问题核心是Supervision版本迭代后的接口变更,加上数据集路径/格式的细节没踩对,导致数据集加载异常、混淆矩阵为空。我来一步步给你梳理正确的解决流程:

一、先解决版本与接口不匹配的问题

你用的是supervision==0.26.1,这个版本的DetectionDataset接口已经和旧教程/笔记不一致了——原来的dataset.images属性被移除,换成了迭代器方式获取样本。比如要查看数据集样本,应该用next(iter(dataset))而非dataset.images.keys()

更推荐升级到稳定兼容版本:配对supervision>=0.28.0ultralytics>=8.0.200,能避免大部分兼容性问题;如果不想升级,要严格对应旧版本的调用逻辑。

二、正确配置YOLO格式数据集加载

你的from_yolo方法参数方向是对的,但要注意几个关键细节(这是最容易导致数据集加载失败的点):

  1. 路径必须精准对应验证集
    • images_directory_path:必须是验证集图片的独立根目录(比如./dataset/val/images),不能用整个数据集的图片根目录,否则会混入训练集样本
    • annotations_directory_path:必须是验证集标注的独立根目录(比如./dataset/val/labels),每个图片必须对应同名的YOLO格式.txt标注(即使无目标也要有空的.txt,否则样本会被跳过)
    • data_yaml_path:数据集YAML文件,里面要正确包含names(类别列表)和nc(类别数),YAML里的数据集路径可忽略,因为我们已手动指定了验证集路径
  2. 验证数据集是否真的加载成功
    加载后用以下方式调试,替代旧的dataset.images
    # 查看类别是否正确匹配
    print("Dataset classes:", dataset.classes)
    # 查看验证集样本总数(如果为0,说明路径完全错误)
    print("Total validation samples:", len(dataset))
    # 取出第一个样本验证标注是否加载
    sample_image, sample_annotations = next(iter(dataset))
    print("Sample image shape:", sample_image.shape)
    print("Sample annotation count:", len(sample_annotations.class_id))
    

三、完整可运行的混淆矩阵计算代码

结合你的场景,给你适配supervision>=0.26.0的完整代码:

import supervision as sv
from ultralytics import YOLO
import numpy as np

# 替换为你的实际路径
IMAGES_VAL_DIR = "./dataset/val/images"   # 验证集图片根目录
ANNOTS_VAL_DIR = "./dataset/val/labels"    # 验证集标注根目录
YAML_PATH = "./dataset/data.yaml"          # 数据集类别YAML文件
MODEL_PATH = "./runs/detect/train/weights/best.pt"  # 训练好的模型权重

# 1. 加载验证集
dataset = sv.DetectionDataset.from_yolo(
    images_directory_path=IMAGES_VAL_DIR,
    annotations_directory_path=ANNOTS_VAL_DIR,
    data_yaml_path=YAML_PATH
)

# 调试:确认数据集加载正常
print(f"Loaded {len(dataset)} validation samples")
print(f"Class list: {dataset.classes}")

# 2. 加载YOLO模型
model = YOLO(MODEL_PATH)

# 3. 定义预测回调函数(适配Supervision要求)
def callback(image: np.ndarray) -> sv.Detections:
    # 保持与训练一致的预处理参数
    results = model.predict(
        source=image,
        imgsz=640,  # 替换为你训练时用的尺寸
        conf=0.25,  # 适当降低阈值避免漏检
        verbose=False  # 关闭预测日志,避免刷屏
    )
    return sv.Detections.from_ultralytics(results[0])

# 4. 计算混淆矩阵
confusion_matrix = sv.ConfusionMatrix.benchmark(
    dataset=dataset,
    callback=callback,
    class_names=dataset.classes  # 显式传入类别,确保映射正确
)

# 5. 可视化并保存结果
confusion_matrix.plot(
    class_names=dataset.classes,
    title="YOLOv8 Validation Confusion Matrix",
    save_path="./confusion_matrix.png"
)

四、空混淆矩阵的常见排查点

如果还是出现空矩阵,按以下顺序排查:

  1. 文件名一致性检查:图片与标注文件名必须完全一致(包括大小写,Linux/macOS区分大小写),比如person_001.jpg对应person_001.txt
  2. 标注格式正确性:YOLO标注每一行是class_id x_center y_center width height,所有坐标必须是归一化后的值(0-1之间),绝对像素值会导致加载失败
  3. 模型预测有效性:在callback函数中加调试代码,确认模型有输出:
    def callback(image: np.ndarray) -> sv.Detections:
        results = model.predict(source=image, conf=0.25, verbose=False)
        print(f"Detected {len(results[0].boxes)} objects in sample")
        return sv.Detections.from_ultralytics(results[0])
    
    如果输出全为0,说明模型在验证集上无检测结果,可能是权重错误、置信度过高,或验证集与训练集分布差异过大
  4. 类别ID顺序一致性:YAML文件的names顺序必须与模型训练时的YAML完全一致,否则类别映射会出错,导致混淆矩阵显示异常

五、旧版本Supervision(0.26.1)适配方案

如果不想升级Supervision,需调整数据集遍历和混淆矩阵生成逻辑:

# 旧版本查看样本的方式
for sample in dataset:
    image, annotations = sample
    print(image.shape, annotations.class_id)
    break

# 旧版本手动计算混淆矩阵
confusion_matrix = sv.ConfusionMatrix(
    num_classes=len(dataset.classes),
    class_names=dataset.classes
)

# 手动遍历数据集更新矩阵
for image, annotations in dataset:
    predictions = callback(image)
    confusion_matrix.update(
        annotations.class_id,
        predictions.class_id
    )

confusion_matrix.plot()

按以上步骤一步步调试,先确保数据集能正确加载出样本和标注,再跑混淆矩阵,就能解决空矩阵的问题了。

火山引擎 最新活动