You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何使用JPMML-Android将scikit-learn生成的PMML随机森林模型部署到Android?

手把手教你用JPMML-Android部署PMML随机森林模型到Android

别慌!作为新手看不懂JPMML-Android的Usage很正常,我给你梳理了一套从0到1的落地流程,跟着一步步来就行:

1. 先确保你的PMML模型是“合格”的

首先得确认你用scikit-learn导出的PMML模型是标准格式,推荐用sklearn2pmml库导出(如果还没装,先执行pip install sklearn2pmml),示例代码:

from sklearn.ensemble import RandomForestClassifier
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

# 假设你已经训练好随机森林模型rf_model,并且有对应的特征处理步骤
pipeline = PMMLPipeline([
    ("classifier", rf_model)
])

# 导出PMML模型,文件名比如RandomForestModel.pmml
sklearn2pmml(pipeline, "RandomForestModel.pmml")

重点提醒:导出时要确保模型里的特征名、数据类型和你后面Android端要传入的完全一致,不然预测会报错!

2. 配置Android项目

2.1 添加依赖

打开你的Android项目,找到app/build.gradle(Module级别的那个),在dependencies块里添加JPMML-Android的依赖:

dependencies {
    // 替换成最新版本号,可去Maven仓库查最新版
    implementation 'org.jpmml:jpmml-android:1.2.0'
    // 如果用Kotlin协程处理后台任务,还要加这个(推荐)
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3'
}

然后同步Gradle。

2.2 放入PMML模型文件

app/src/main目录下新建assets文件夹(如果没有的话),把你导出的RandomForestModel.pmml复制进去。

3. 核心代码实现(Kotlin示例,Java逻辑类似)

3.1 加载PMML模型(必须在后台线程!)

加载模型是耗时操作,不能放在主线程,否则会触发ANR(应用无响应)。这里用Kotlin协程来处理:

import android.os.Bundle
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.jpmml.evaluator.Evaluator
import org.jpmml.evaluator.android.LoaderUtil

class MainActivity : AppCompatActivity() {
    // 声明模型评估器
    private lateinit var evaluator: Evaluator

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        
        // 启动加载模型
        loadModel()
    }

    private fun loadModel() {
        GlobalScope.launch(Dispatchers.Main) {
            try {
                evaluator = withContext(Dispatchers.IO) {
                    // 从assets读取PMML文件并加载为Evaluator
                    val inputStream = assets.open("RandomForestModel.pmml")
                    LoaderUtil.loadEvaluator(inputStream)
                }
                Toast.makeText(this@MainActivity, "模型加载完成!", Toast.LENGTH_SHORT).show()
            } catch (e: Exception) {
                Toast.makeText(this@MainActivity, "模型加载失败:${e.message}", Toast.LENGTH_LONG).show()
            }
        }
    }
}

3.2 准备输入数据

根据你的模型特征,构造一个键值对Map,键是模型里的特征名,值是对应的数据类型(必须和模型定义一致!):

// 示例:假设你的模型有3个特征:age(整数), income(浮点数), score(浮点数)
private fun prepareInput(age: Int, income: Double, score: Float): Map<String, Any> {
    return mapOf(
        "age" to age,
        "income" to income,
        "score" to score
    )
}

3.3 执行预测并处理结果

同样,预测操作也建议放在后台线程,避免阻塞UI:

import org.jpmml.evaluator.android.EvaluatorUtil

private fun predict(input: Map<String, Any>): Any? {
    // 先检查模型是否加载完成
    if (!::evaluator.isInitialized) {
        return null
    }
    // 执行预测,结果是一个Map,键是目标变量名,值是预测结果
    val resultMap = EvaluatorUtil.evaluate(evaluator, input)
    // 取第一个结果(如果是分类模型,可能是类别或概率;回归模型是数值)
    return resultMap.values.firstOrNull()
}

// 比如在按钮点击事件里调用
button_predict.setOnClickListener {
    // 从UI输入框获取数据(这里要处理输入为空或格式错误的情况,示例省略)
    val age = edit_text_age.text.toString().toInt()
    val income = edit_text_income.text.toString().toDouble()
    val score = edit_text_score.text.toString().toFloat()
    
    val inputData = prepareInput(age, income, score)
    
    GlobalScope.launch(Dispatchers.Main) {
        val prediction = withContext(Dispatchers.IO) {
            predict(inputData)
        }
        prediction?.let {
            text_view_result.text = "预测结果:$it"
        } ?: run {
            Toast.makeText(this@MainActivity, "模型未加载完成,请稍后再试", Toast.LENGTH_SHORT).show()
        }
    }
}

4. 避坑指南

  • 数据类型严格匹配:如果模型里的特征是Integer,就不能传StringDouble,否则会抛出类型不匹配的异常。
  • 模型大小优化:如果你的PMML模型很大,可以用JPMML-Model工具压缩模型(比如移除冗余信息),加快加载速度。
  • 权限问题:如果模型放在外部存储,需要申请READ_EXTERNAL_STORAGE权限,但推荐直接放在assets目录,无需额外权限。
  • 异常处理:加载模型和预测时一定要加try-catch,避免崩溃。

内容的提问来源于stack exchange,提问作者Yuerno

火山引擎 最新活动