You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

Android中使用mobilenet_v1_1.0_224.tflite如何获取目标检测边界框?

解决Android TFLite获取边界框的问题

首先,你说得没错——当前的代码确实无法返回边界框信息,核心原因是你正在使用的是图像分类TFLite模型,这类模型的设计目标就是只输出图像(或图像区域)的类别概率,完全没有边界框相关的输出张量。要获取目标的边界框,你需要更换为目标检测TFLite模型,下面是具体的步骤和代码调整建议:

1. 明确模型类型差异

  • 图像分类模型:输入一张图像,输出每个类别的概率值(就是你当前代码里的FloatArray),只回答"这是什么"。
  • 目标检测模型:输入一张图像,输出多个检测结果,每个结果包含:边界框坐标(xmin, ymin, xmax, ymax)、类别ID、置信度(该检测结果的可靠程度),能回答"这是什么,在哪里"。

2. 更换为目标检测TFLite模型

你需要选择一个适合移动端的目标检测TFLite模型,比如:

  • SSD-MobileNet系列(轻量,适合移动端)
  • EfficientDet-Lite系列(精度和速度平衡)
  • YOLOv5/YOLOv8的TFLite版本(需要自己导出或找预编译版本)

注意:下载模型时要确保是TFLite格式(.tflite),同时配套对应的labels.txt(类别列表要和模型训练时的一致)。

3. 修改代码适配目标检测模型

目标检测模型的输入输出结构和分类模型完全不同,需要调整你的代码逻辑:

3.1 调整模型初始化与输入参数

首先,你需要确认新模型的输入尺寸(比如SSD-MobileNet通常是300x300,EfficientDet-Lite0是320x320),替换原来的inputImageWidthinputImageHeight

3.2 修改检测逻辑(替换原classify函数)

目标检测模型的输出张量结构因模型而异,举个常见的例子(以SSD-MobileNet V2为例):

  • 输出通常包含四个张量:边界框坐标、类别概率、置信度、实际检测目标数。

下面是适配这类模型的示例代码:

// 定义数据类存储检测结果
data class DetectionResult(
    val label: String,
    val confidence: Float,
    val boundingBox: RectF // 存储边界框的绝对坐标
)

private fun detectObjects(bitmap: Bitmap): String {
    check(isInitialized) { "TF Lite Interpreter is not initialized yet." }
    
    // 调整输入图像尺寸为模型要求的大小
    val resizedImage = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true)
    val byteBuffer = convertBitmapToByteBuffer(resizedImage)
    
    // 根据模型输出结构定义输出张量,这里以SSD-MobileNet V2为例
    val outputLocations = Array(1) { Array(10) { FloatArray(4) } } // 边界框坐标
    val outputClasses = Array(1) { Array(10) { FloatArray(91) } } // 类别概率
    val outputScores = Array(1) { FloatArray(10) } // 置信度
    val numDetections = FloatArray(1) // 实际检测到的目标数
    
    val startTime = SystemClock.uptimeMillis()
    // 运行模型,注意输出张量的顺序要和模型定义一致
    interpreter?.runForMultipleInputsOutputs(
        arrayOf(byteBuffer),
        mapOf(
            0 to outputLocations,
            1 to outputClasses,
            2 to outputScores,
            3 to numDetections
        )
    )
    val endTime = SystemClock.uptimeMillis()
    val inferenceTime = endTime - startTime
    
    // 解析检测结果
    val results = mutableListOf<DetectionResult>()
    val numDetected = numDetections[0].toInt()
    for (i in 0 until numDetected) {
        val confidence = outputScores[0][i]
        // 过滤低置信度的结果,比如只保留置信度>0.5的
        if (confidence < 0.5) continue
        
        // 获取类别ID
        val classIndex = outputClasses[0][i].indexOfFirst { it == outputClasses[0][i].max() }
        val label = labels.getOrElse(classIndex) { "Unknown" }
        
        // 将相对坐标转换为原始图像的绝对坐标
        val xmin = outputLocations[0][i][0] * bitmap.width
        val ymin = outputLocations[0][i][1] * bitmap.height
        val xmax = outputLocations[0][i][2] * bitmap.width
        val ymax = outputLocations[0][i][3] * bitmap.height
        val boundingBox = RectF(xmin, ymin, xmax, ymax)
        
        results.add(DetectionResult(label, confidence, boundingBox))
    }
    
    // 构建返回字符串,包含所有检测结果和推理时间
    val resultBuilder = StringBuilder()
    resultBuilder.append("Inference Time: ${inferenceTime}ms\n")
    if (results.isEmpty()) {
        resultBuilder.append("No objects detected.")
    } else {
        results.forEachIndexed { index, result ->
            resultBuilder.append("${index + 1}. ${result.label} (${String.format("%.2f", result.confidence * 100)}%)\n")
            resultBuilder.append("  Bounding Box: (${result.boundingBox.left.toInt()}, ${result.boundingBox.top.toInt()}) - (${result.boundingBox.right.toInt()}, ${result.boundingBox.bottom.toInt()})\n")
        }
    }
    return resultBuilder.toString()
}

3.3 注意事项

  • 不同模型的输出张量结构可能不同,你可以使用Netron工具可视化模型的输入输出,确认张量的形状和含义。
  • 边界框的坐标通常是相对输入图像的比例值,需要乘以原始图像的宽高转换为绝对坐标。
  • 记得更新labels.txt,确保类别列表和模型训练时的类别顺序一致。

总结

要获取边界框,必须更换为目标检测TFLite模型,然后根据新模型的输入输出结构调整代码逻辑,解析边界框、类别和置信度信息。

内容的提问来源于stack exchange,提问作者MrRobot9

火山引擎 最新活动