YOLOv11导出TensorFlow后输出张量形状解读及后处理代码正确性验证
YOLOv11导出TensorFlow后输出张量形状解读及后处理代码正确性验证
我来帮你理清YOLOv11导出TensorFlow后的输出逻辑,以及修正你的后处理代码问题。
首先,你导出模型时看到的输出信息如下:
'yolo11n.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 84, 8400) (5.4 MB)
对应的Keras模型结构总结:
Model: "functional_1" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 640, 640, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ tfsm_layer_8 (TFSMLayer) │ (1, 84, 8400) │ 0 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘ Total params: 0 (0.00 B) Trainable params: 0 (0.00 B) Non-trainable params: 0 (0.00 B)
这个输出形状(1, 84, 8400)的含义和YOLOv8是一脉相承的,参考相关说明:
The output shapes for YOLOv8n and YOLOv8n-seg models represent different components. For YOLOv8n, the shape (1, 84, 8400) includes 80 classes and 4 bounding box parameters. For YOLOv8n-seg, the first output (1, 116, 8400) includes 80 classes, 4 parameters, and 32 mask coefficients, while the second output (1, 32, 160, 160) represents the prototype masks.
简单拆解一下:
1:推理的批次大小(你这里是单张图片)84:4个边界框参数(x中心、y中心、宽、高) + 80个COCO类别的置信度得分8400:640x640输入下,三个检测尺度拼接后的总锚框数量
你的后处理代码问题分析
你当前的代码存在一个核心错误:YOLOv11的输出里没有单独的"通用置信度"通道,你错误地把第5个值当成了置信度,但实际上第5到第84位是80个类别的专属置信度得分。而且代码没有做非极大值抑制(NMS),导致出现大量重复框。
修正后的后处理代码
import numpy as np import cv2 import matplotlib.pyplot as plt # 假设模型输出为 (1, 84, 8400),先去除批次维度 output = model.predict(image)[0] # 形状变为 (84, 8400) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 统一为RGB格式用于显示 # 1. 分离边界框参数与类别置信度 boxes_raw = output[:4, :].T # (8400, 4) → 格式为(x_center, y_center, width, height) class_scores = output[4:, :].T # (8400, 80) → 每个锚框对应80个类别的得分 # 2. 提取每个锚框的最大类别得分及对应类别ID max_scores = np.max(class_scores, axis=1) # (8400,) class_ids = np.argmax(class_scores, axis=1) # (8400,) # 3. 过滤低得分锚框 conf_threshold = 0.25 valid_indices = np.where(max_scores > conf_threshold)[0] filtered_boxes = boxes_raw[valid_indices] filtered_scores = max_scores[valid_indices] filtered_classes = class_ids[valid_indices] # 4. 转换框格式:(x_center, y_center, w, h) → (x1, y1, x2, y2) def xywh2xyxy(x): y = np.copy(x) y[:, 0] = x[:, 0] - x[:, 2] / 2 # x1 = 中心x - 宽度/2 y[:, 1] = x[:, 1] - x[:, 3] / 2 # y1 = 中心y - 高度/2 y[:, 2] = x[:, 0] + x[:, 2] / 2 # x2 = 中心x + 宽度/2 y[:, 3] = x[:, 1] + x[:, 3] / 2 # y2 = 中心y + 高度/2 return y filtered_boxes = xywh2xyxy(filtered_boxes) # 5. 执行非极大值抑制(NMS)去除重复框 nms_threshold = 0.5 keep_indices = cv2.dnn.NMSBoxes( filtered_boxes.tolist(), filtered_scores.tolist(), score_threshold=conf_threshold, nms_threshold=nms_threshold )[0] # 6. 筛选最终的检测结果 final_boxes = filtered_boxes[keep_indices] final_scores = filtered_scores[keep_indices] final_classes = filtered_classes[keep_indices] # 7. 绘制检测结果(使用COCO类别名称示例) class_names = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] for box, score, cls_id in zip(final_boxes, final_scores, final_classes): x1, y1, x2, y2 = map(int, box) # 确保坐标在图像范围内 x1 = max(0, x1) y1 = max(0, y1) x2 = min(image.shape[1], x2) y2 = min(image.shape[0], y2) # 绘制框和标签 cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) label = f"{class_names[cls_id]}: {score:.2f}" cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) # 显示结果图像 plt.figure(figsize=(10, 6)) plt.imshow(image) plt.axis("off") plt.show()
你原始代码的结果问题
你之前得到的满屏红色框是因为:
- 错误提取了第一个类别的得分作为通用置信度,导致过滤逻辑失效
- 没有执行NMS,所有符合低阈值的锚框都被绘制出来,出现大量重复检测
修正后的代码会先过滤低得分锚框,再通过NMS去除重复框,得到精准的检测结果。

备注:内容来源于stack exchange,提问作者Muhammad Ikhwan Perwira




