Android中使用mobilenet_v1_1.0_224.tflite如何获取目标检测边界框?
解决Android TFLite获取边界框的问题
首先,你说得没错——当前的代码确实无法返回边界框信息,核心原因是你正在使用的是图像分类TFLite模型,这类模型的设计目标就是只输出图像(或图像区域)的类别概率,完全没有边界框相关的输出张量。要获取目标的边界框,你需要更换为目标检测TFLite模型,下面是具体的步骤和代码调整建议:
1. 明确模型类型差异
- 图像分类模型:输入一张图像,输出每个类别的概率值(就是你当前代码里的
FloatArray),只回答"这是什么"。 - 目标检测模型:输入一张图像,输出多个检测结果,每个结果包含:边界框坐标(xmin, ymin, xmax, ymax)、类别ID、置信度(该检测结果的可靠程度),能回答"这是什么,在哪里"。
2. 更换为目标检测TFLite模型
你需要选择一个适合移动端的目标检测TFLite模型,比如:
- SSD-MobileNet系列(轻量,适合移动端)
- EfficientDet-Lite系列(精度和速度平衡)
- YOLOv5/YOLOv8的TFLite版本(需要自己导出或找预编译版本)
注意:下载模型时要确保是TFLite格式(.tflite),同时配套对应的labels.txt(类别列表要和模型训练时的一致)。
3. 修改代码适配目标检测模型
目标检测模型的输入输出结构和分类模型完全不同,需要调整你的代码逻辑:
3.1 调整模型初始化与输入参数
首先,你需要确认新模型的输入尺寸(比如SSD-MobileNet通常是300x300,EfficientDet-Lite0是320x320),替换原来的inputImageWidth和inputImageHeight。
3.2 修改检测逻辑(替换原classify函数)
目标检测模型的输出张量结构因模型而异,举个常见的例子(以SSD-MobileNet V2为例):
- 输出通常包含四个张量:边界框坐标、类别概率、置信度、实际检测目标数。
下面是适配这类模型的示例代码:
// 定义数据类存储检测结果 data class DetectionResult( val label: String, val confidence: Float, val boundingBox: RectF // 存储边界框的绝对坐标 ) private fun detectObjects(bitmap: Bitmap): String { check(isInitialized) { "TF Lite Interpreter is not initialized yet." } // 调整输入图像尺寸为模型要求的大小 val resizedImage = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true) val byteBuffer = convertBitmapToByteBuffer(resizedImage) // 根据模型输出结构定义输出张量,这里以SSD-MobileNet V2为例 val outputLocations = Array(1) { Array(10) { FloatArray(4) } } // 边界框坐标 val outputClasses = Array(1) { Array(10) { FloatArray(91) } } // 类别概率 val outputScores = Array(1) { FloatArray(10) } // 置信度 val numDetections = FloatArray(1) // 实际检测到的目标数 val startTime = SystemClock.uptimeMillis() // 运行模型,注意输出张量的顺序要和模型定义一致 interpreter?.runForMultipleInputsOutputs( arrayOf(byteBuffer), mapOf( 0 to outputLocations, 1 to outputClasses, 2 to outputScores, 3 to numDetections ) ) val endTime = SystemClock.uptimeMillis() val inferenceTime = endTime - startTime // 解析检测结果 val results = mutableListOf<DetectionResult>() val numDetected = numDetections[0].toInt() for (i in 0 until numDetected) { val confidence = outputScores[0][i] // 过滤低置信度的结果,比如只保留置信度>0.5的 if (confidence < 0.5) continue // 获取类别ID val classIndex = outputClasses[0][i].indexOfFirst { it == outputClasses[0][i].max() } val label = labels.getOrElse(classIndex) { "Unknown" } // 将相对坐标转换为原始图像的绝对坐标 val xmin = outputLocations[0][i][0] * bitmap.width val ymin = outputLocations[0][i][1] * bitmap.height val xmax = outputLocations[0][i][2] * bitmap.width val ymax = outputLocations[0][i][3] * bitmap.height val boundingBox = RectF(xmin, ymin, xmax, ymax) results.add(DetectionResult(label, confidence, boundingBox)) } // 构建返回字符串,包含所有检测结果和推理时间 val resultBuilder = StringBuilder() resultBuilder.append("Inference Time: ${inferenceTime}ms\n") if (results.isEmpty()) { resultBuilder.append("No objects detected.") } else { results.forEachIndexed { index, result -> resultBuilder.append("${index + 1}. ${result.label} (${String.format("%.2f", result.confidence * 100)}%)\n") resultBuilder.append(" Bounding Box: (${result.boundingBox.left.toInt()}, ${result.boundingBox.top.toInt()}) - (${result.boundingBox.right.toInt()}, ${result.boundingBox.bottom.toInt()})\n") } } return resultBuilder.toString() }
3.3 注意事项
- 不同模型的输出张量结构可能不同,你可以使用Netron工具可视化模型的输入输出,确认张量的形状和含义。
- 边界框的坐标通常是相对输入图像的比例值,需要乘以原始图像的宽高转换为绝对坐标。
- 记得更新
labels.txt,确保类别列表和模型训练时的类别顺序一致。
总结
要获取边界框,必须更换为目标检测TFLite模型,然后根据新模型的输入输出结构调整代码逻辑,解析边界框、类别和置信度信息。
内容的提问来源于stack exchange,提问作者MrRobot9




