You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

YOLOv5 ONNX模型Java部署问题:类别ID异常与边界框偏移

YOLOv5 ONNX模型Java部署问题:类别ID异常与边界框偏移

看起来你在把YOLOv5的ONNX模型迁移到Java时踩了两个典型的坑:所有检测结果的类别ID全是0,还有边界框位置偏移得离谱。我帮你拆解一下问题根源,直接给你可行的修复方案。


一、类别ID全为0的核心原因

先看你的ONNX模型输出维度:float32[1,25200,15]。这里的15个通道含义是:
[center_x, center_y, width, height, confidence, class0_score, class1_score, ..., class9_score]

  • 前5个值是检测框的位置和目标置信度
  • 后面10个是你10个类别的预测分数(0-1之间的浮点数)

你现在直接取row[5]转成int作为classID,这完全错了:row[5]只是第一个类别的预测分数,不是类别ID!因为分数是0-1的浮点数,转成int自然就是0,这就是为什么所有结果的classID都是0。

修复方案

遍历row[5]row[14]这10个类别分数,找到最大值对应的索引,这个索引就是真实的classID:

// 提取10个类别的预测分数
float[] classScores = new float[10];
System.arraycopy(row, 5, classScores, 0, 10);

// 找到最大分数对应的类别ID
int classId = 0;
float maxScore = classScores[0];
for (int i = 1; i < classScores.length; i++) {
    if (classScores[i] > maxScore) {
        maxScore = classScores[i];
        classId = i;
    }
}
// 这里的classId对应模型metadata里的0-9,对应类别'1'-'10'

二、边界框偏移/变形的问题

这个问题是两个关键错误叠加导致的:

1. ONNX输出坐标格式理解错误

YOLOv5的ONNX原始输出是中心点+宽高格式(center_x, center_y, width, height),但你当成了左上角+右下角x1,y1,x2,y2)来处理,这直接导致坐标计算完全混乱。

2. Java drawRect方法参数用错

Java的Graphics.drawRect(int x, int y, int width, int height)后两个参数是矩形的宽和高,不是右下角坐标。你现在直接传入x2和y2,相当于把右下角坐标当成了宽高,框肯定变形。

修复方案

先把中心点格式转成x1y1x2y2,再计算宽高用于绘图,最后处理缩放和padding:

public static void postProcess(float[][] detections, int origWidth, int origHeight, BufferedImage originalImage) {
    System.out.println("\nDetections:");

    int inputSize = 640;
    float gain = Math.min((float) inputSize / origWidth, (float) inputSize / origHeight);
    float padX = (inputSize - origWidth * gain) / 2;
    float padY = (inputSize - origHeight * gain) / 2;

    BufferedImage outputImage = new BufferedImage(origWidth, origHeight, BufferedImage.TYPE_INT_RGB);
    Graphics g = outputImage.getGraphics();
    g.drawImage(originalImage, 0, 0, null);
    g.setStroke(new BasicStroke(2)); // 让边界框更清晰

    for (float[] row : detections) {
        if (row.length < 15) continue;

        // 1. 提取ONNX原始输出的中心点+宽高格式
        float centerX = row[0];
        float centerY = row[1];
        float bboxWidth = row[2];
        float bboxHeight = row[3];
        float objConfidence = row[4];

        // 2. 转换为x1,y1,x2,y2格式
        float x1 = centerX - bboxWidth / 2;
        float y1 = centerY - bboxHeight / 2;
        float x2 = centerX + bboxWidth / 2;
        float y2 = centerY + bboxHeight / 2;

        // 3. 计算类别ID(修复全0问题)
        float[] classScores = new float[10];
        System.arraycopy(row, 5, classScores, 0, 10);
        int classId = 0;
        float maxClassScore = classScores[0];
        for (int i = 1; i < 10; i++) {
            if (classScores[i] > maxClassScore) {
                maxClassScore = classScores[i];
                classId = i;
            }
        }
        // YOLOv5的最终置信度是目标置信度×类别分数
        float finalConfidence = objConfidence * maxClassScore;

        if (finalConfidence > 0.5) {
            // 4. 修正坐标到原始图像尺寸
            x1 = (x1 - padX) / gain;
            y1 = (y1 - padY) / gain;
            x2 = (x2 - padX) / gain;
            y2 = (y2 - padY) / gain;

            // 裁剪到图像边界,避免超出范围
            x1 = Math.max(0, Math.min(x1, origWidth));
            y1 = Math.max(0, Math.min(y1, origHeight));
            x2 = Math.max(0, Math.min(x2, origWidth));
            y2 = Math.max(0, Math.min(y2, origHeight));

            // 5. 计算drawRect需要的宽高
            int rectX = (int) Math.round(x1);
            int rectY = (int) Math.round(y1);
            int rectWidth = (int) Math.round(x2 - x1);
            int rectHeight = (int) Math.round(y2 - y1);

            System.out.printf("BBox: [%.2f, %.2f, %.2f, %.2f], Confidence: %.2f, Class ID: %d%n",
                    x1, y1, x2, y2, finalConfidence, classId);

            // 绘制边界框
            g.setColor(Color.RED);
            g.drawRect(rectX, rectY, rectWidth, rectHeight);

            // 绘制标签
            g.setColor(Color.WHITE);
            g.drawString(String.format("ID: %d Conf: %.2f", classId, finalConfidence), rectX, rectY - 10);
        }
    }
    g.dispose();

    try {
        ImageIO.write(outputImage, "jpg", new File("output.jpg"));
        System.out.println("Output image saved as output.jpg");
    } catch (Exception e) {
        e.printStackTrace();
    }
}

额外优化建议

  1. 图像Resize优化:你现在的resize方法是直接拉伸图像,会导致图像变形。YOLOv5的预处理是保持比例,对空白区域填充RGB(114,114,114),你可以修改resizeImage方法实现这一点,避免因图像变形导致的检测偏差。
  2. 添加NMS:你提到还没实现NMS,这会导致大量重复的低置信度框,建议后续实现非极大值抑制,保留最优的检测框。
  3. 通道顺序确认:YOLOv5的PyTorch模型输入是RGB顺序,你的Java代码里把R、G、B分别放到tensor的0、1、2通道是对的,这部分没问题。

备注:内容来源于stack exchange,提问作者CoffeeCoding

火山引擎 最新活动