You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何将训练好的Python machine learning model接入原生Android App?

如何将训练好的Python machine learning model接入原生Android App?

Hey there! 我之前也折腾过把Python训练的ML模型搬到Android上,踩过不少坑,给你分享几个靠谱的、亲测有效的方案,一步步来都能搞定:

方案1:用TensorFlow Lite(最常用,适配TF/Keras模型)

如果你的模型是用TensorFlow或Keras训练的,这绝对是首选方案,步骤清晰且官方支持完善:

  • 第一步:Python端将模型转换为TFLite格式
    打开你的训练脚本,添加几行代码就能完成转换:
    # 假设你已经训练好名为model的Keras/TF模型
    import tensorflow as tf
    
    # 初始化转换器
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # 可选:开启量化优化,大幅减小模型体积并提升运行速度
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # 执行转换
    tflite_model = converter.convert()
    
    # 保存为.tflite文件
    with open("your_model.tflite", "wb") as f:
        f.write(tflite_model)
    
  • 第二步:将模型文件放入Android项目
    在Android Studio中,把生成的your_model.tflite复制到app/src/main/assets目录下(如果没有assets文件夹,右键app模块 → New → Folder → Assets Folder创建即可)。
  • 第三步:Android端加载并运行模型
    首先在Module级别的build.gradle中添加依赖:
    dependencies {
        implementation 'org.tensorflow:tensorflow-lite:2.15.0'
        // 可选:引入TFLite Support库,简化数据预处理/后处理
        implementation 'org.tensorflow:tensorflow-lite-support:0.4.4'
    }
    
    然后用Kotlin/Java编写模型加载与推理逻辑,这里是Kotlin示例:
    import org.tensorflow.lite.Interpreter
    import java.io.FileInputStream
    import java.nio.MappedByteBuffer
    import java.nio.channels.FileChannel
    
    class TFLiteModelHelper(private val context: Context) {
        private lateinit var tfliteInterpreter: Interpreter
    
        init {
            loadModel()
        }
    
        private fun loadModel() {
            val modelBuffer = loadModelFromAssets()
            tfliteInterpreter = Interpreter(modelBuffer)
        }
    
        private fun loadModelFromAssets(): MappedByteBuffer {
            val assetFd = context.assets.openFd("your_model.tflite")
            val inputStream = FileInputStream(assetFd.fileDescriptor)
            val fileChannel = inputStream.channel
            return fileChannel.map(
                FileChannel.MapMode.READ_ONLY,
                assetFd.startOffset,
                assetFd.declaredLength
            )
        }
    
        // 推理函数,需根据你的模型输入输出维度调整
        fun runInference(inputData: FloatArray): FloatArray {
            // 示例:假设模型输出是单个数值,对应长度为1的FloatArray
            val output = FloatArray(1)
            tfliteInterpreter.run(inputData, output)
            return output
        }
    }
    

方案2:用ONNX Runtime(适配多框架模型,如PyTorch、Scikit-learn)

如果你的模型是用PyTorch、Scikit-learn等非TF框架训练的,先转成ONNX格式,再用Android版ONNX Runtime运行是最优解:

  • 第一步:Python端将模型转ONNX格式
    以PyTorch模型为例:
    import torch
    
    # 加载训练好的PyTorch模型
    model = torch.load("your_model.pth")
    model.eval()
    
    # 构造与模型实际输入维度一致的示例输入
    example_input = torch.randn(1, 3, 224, 224)  # 示例:ResNet类模型的输入形状
    
    # 导出为ONNX格式
    torch.onnx.export(
        model,
        example_input,
        "your_model.onnx",
        input_names=["input"],
        output_names=["output"],
        opset_version=11  # 选择兼容Android Runtime的opset版本,11+均可
    )
    
  • 第二步:Android项目集成ONNX Runtime
    在Module级build.gradle中添加依赖:
    dependencies {
        implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.15.1'
    }
    
  • 第三步:Android端加载模型并推理
    Kotlin示例代码:
    import ai.onnxruntime.OrtEnvironment
    import ai.onnxruntime.OrtSession
    import java.io.File
    
    class ONNXModelHelper(private val context: Context) {
        private lateinit var session: OrtSession
        private val env = OrtEnvironment.getEnvironment()
    
        init {
            loadModel()
        }
    
        private fun loadModel() {
            // 将assets中的模型复制到内部存储(ONNX Runtime需要文件路径,无法直接读取assets流)
            val targetFile = File(context.filesDir, "your_model.onnx")
            if (!targetFile.exists()) {
                context.assets.open("your_model.onnx").use { input ->
                    targetFile.outputStream().use { output ->
                        input.copyTo(output)
                    }
                }
            }
            session = env.createSession(targetFile.absolutePath)
        }
    
        fun runInference(inputData: FloatArray): FloatArray {
            // 构造输入张量,需与模型输入维度匹配
            val inputTensor = OrtSession.Result.Value.createTensor(
                env,
                inputData,
                longArrayOf(1, 3, 224, 224)
            )
            val output = session.run(mapOf("input" to inputTensor))
            val result = output.get(0).value as FloatArray
            // 记得释放资源
            output.close()
            inputTensor.close()
            return result
        }
    }
    

方案3:用PyTorch Mobile(专为PyTorch模型打造)

如果你的模型是纯PyTorch训练的,用PyTorch Mobile可以跳过格式转换的麻烦,直接运行TorchScript模型:

  • 第一步:Python端将模型转TorchScript格式
    import torch
    
    model = torch.load("your_model.pth")
    model.eval()
    
    # 构造示例输入
    example_input = torch.randn(1, 3, 224, 224)
    # 追踪模型生成TorchScript
    scripted_model = torch.jit.trace(model, example_input)
    # 保存为.ptl文件
    scripted_model.save("your_model.ptl")
    
  • 第二步:Android集成PyTorch Mobile
    在Module级build.gradle中添加依赖:
    dependencies {
        implementation 'org.pytorch:pytorch_android:1.13.1'
        implementation 'org.pytorch:pytorch_android_torchvision:1.13.1'
    }
    
  • 第三步:Android端加载模型并推理
    Kotlin示例:
    import org.pytorch.IValue
    import org.pytorch.Module
    import org.pytorch.torchvision.TensorImageUtils
    import android.graphics.Bitmap
    
    class PyTorchModelHelper(private val context: Context) {
        private lateinit var module: Module
    
        init {
            // 直接加载assets中的TorchScript模型
            module = Module.load(context.assets.openFd("your_model.ptl"))
        }
    
        // 如果你的模型输入是图片,用这个方法更便捷
        fun runInference(bitmap: Bitmap): FloatArray {
            // 用TorchVision工具将Bitmap转换为模型所需的张量
            val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
                bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                TensorImageUtils.TORCHVISION_NORM_STD_RGB
            )
            val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
            return outputTensor.dataAsFloatArray
        }
    }
    

避坑小提醒

  • 模型量化:不管用哪个方案,都建议开启模型量化(比如TFLite的Optimize.DEFAULT、ONNX的量化工具),能大幅减小模型体积,同时提升Android设备上的运行速度,对中低端设备效果尤其明显。
  • 输入输出匹配:务必确保Android端构造的输入张量数据类型(如float32)、维度形状(如[1,28,28])与Python训练时的输入完全一致,否则会出现推理错误或崩溃。
  • 资源释放:Android端加载模型后,记得在合适的时机(如Activity销毁时)释放模型资源,避免内存泄漏。

如果你的模型是Scikit-learn这类传统机器学习模型,建议先转成ONNX格式,再用方案2的ONNX Runtime方案。有具体模型框架或遇到具体报错的话,随时说,我再给你细化解决!

火山引擎 最新活动