You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何将Google Cloud训练的TensorFlow.pb模型部署到Android本地

嗨,这事儿我熟!把你在GCP上训练好的TensorFlow .pb模型部署到Android本地运行完全没问题,我给你一步步捋清楚:

步骤1:先把模型转成更适合移动设备的格式(可选但强烈推荐)

虽然直接用原始的.pb文件也能通过TensorFlow Mobile在Android上跑,但**TensorFlow Lite(TFLite)**是谷歌专为移动/嵌入式设备优化的框架,模型体积更小、推理速度更快,还支持硬件加速,所以优先建议把.pb转成.tflite格式。

转换步骤很简单,用Python脚本就能搞定:

  1. 先确保装了TensorFlow:pip install tensorflow
  2. 写个转换脚本(根据你的模型类型选对应方式):
    • 如果是SavedModel格式的.pb(从GCP训练导出的通常是这种):
    import tensorflow as tf
    from tensorflow.lite.python import lite
    
    # 加载本地的SavedModel
    saved_model = tf.saved_model.load("/你的模型路径/")
    # 获取默认的推理签名
    concrete_func = saved_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    
    # 转换为TFLite模型
    converter = lite.TFLiteConverter.from_concrete_functions([concrete_func])
    # 可选:开启优化(比如量化,进一步缩小体积、提升速度)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    
    # 保存.tflite文件
    with open("model.tflite", "wb") as f:
        f.write(tflite_model)
    
    • 如果是冻结图(Frozen Graph)格式的.pb,就用lite.TFLiteConverter.from_frozen_graph(),需要提前用TensorBoard或者tf.compat.v1.get_default_graph()查看模型的输入、输出节点名称。
步骤2:配置Android项目

打开Android Studio,给你的项目加上TFLite支持:

  1. app/build.gradledependencies块里添加依赖:
    dependencies {
        // 核心TFLite库
        implementation 'org.tensorflow:tensorflow-lite:2.15.0'
        // 可选:TFLite支持库,简化输入输出处理
        implementation 'org.tensorflow:tensorflow-lite-support:0.4.4'
    }
    
  2. 把转换好的model.tflite(或者原始.pb文件)放到app/src/main/assets目录下(如果没有assets文件夹就新建一个)。
  3. 为了避免模型文件被压缩,在app/build.gradleandroid块里加上:
    android {
        ...
        aaptOptions {
            noCompress "tflite" // 不对.tflite文件压缩
            noCompress "pb"    // 如果用原始.pb的话加这行
        }
    }
    
步骤3:在Android代码里加载并运行模型

这里以TFLite为例,写一个简单的模型管理类,方便调用:

import org.tensorflow.lite.Interpreter;
import android.content.Context;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class TFLiteModelManager {
    private Interpreter tfliteInterpreter;

    // 初始化模型
    public TFLiteModelManager(Context context) throws IOException {
        // 从assets加载模型文件
        MappedByteBuffer modelBuffer = loadModelFromAssets(context);
        // 配置Interpreter,比如开启NNAPI硬件加速
        Interpreter.Options options = new Interpreter.Options();
        options.setUseNNAPI(true);
        options.setNumThreads(4); // 用4线程加速推理
        tfliteInterpreter = new Interpreter(modelBuffer, options);
    }

    // 读取assets里的模型文件
    private MappedByteBuffer loadModelFromAssets(Context context) throws IOException {
        FileInputStream inputStream = new FileInputStream(
            context.getAssets().openFd("model.tflite").getFileDescriptor()
        );
        FileChannel channel = inputStream.getChannel();
        long startOffset = context.getAssets().openFd("model.tflite").getStartOffset();
        long fileLength = context.getAssets().openFd("model.tflite").getDeclaredLength();
        return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, fileLength);
    }

    // 推理方法,根据你的模型输入输出调整参数
    public float[] runInference(float[] modelInput) {
        // 初始化输出数组,要和模型的输出维度完全匹配
        float[] modelOutput = new float[/* 替换成你的输出维度大小,比如分类任务的类别数 */];
        tfliteInterpreter.run(modelInput, modelOutput);
        return modelOutput;
    }

    // 记得在页面销毁时关闭Interpreter,避免内存泄漏
    public void closeModel() {
        if (tfliteInterpreter != null) {
            tfliteInterpreter.close();
        }
    }
}

如果非要用原始的.pb文件(不推荐,因为TensorFlow Mobile已停止维护),可以把依赖换成implementation 'org.tensorflow:tensorflow-android:1.15.0',然后用TensorFlowInferenceInterface类来加载和推理。

步骤4:处理输入输出数据

这一步很关键,必须和你训练模型时的预处理/后处理逻辑保持一致:

  • 比如你的模型是图像分类,就得把Android的Bitmap转换成模型需要的格式:调整尺寸、归一化像素值、转换通道顺序(比如从ARGB转RGB)。
  • 输出数据要根据任务解析:分类任务取概率最高的类别,检测任务解析边界框坐标,等等。

举个图像预处理的例子:

public float[] preprocessBitmap(Bitmap bitmap) {
    // 把Bitmap缩放到模型要求的尺寸,比如224x224
    Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
    int[] pixels = new int[224 * 224];
    resizedBitmap.getPixels(pixels, 0, 224, 0, 0, 224, 224);
    
    // 转换成模型需要的float数组(假设是RGB三通道,归一化到0-1)
    float[] inputArray = new float[224 * 224 * 3];
    for (int i = 0; i < pixels.length; i++) {
        int pixel = pixels[i];
        inputArray[i * 3] = ((pixel >> 16) & 0xFF) / 255.0f;   // R通道
        inputArray[i * 3 + 1] = ((pixel >> 8) & 0xFF) / 255.0f; // G通道
        inputArray[i * 3 + 2] = (pixel & 0xFF) / 255.0f;        // B通道
    }
    return inputArray;
}
一些额外的注意事项
  • 权限:如果模型需要处理摄像头图像,记得在AndroidManifest.xml里添加<uses-permission android:name="android.permission.CAMERA"/>,并在代码里动态申请权限。
  • 性能优化:除了NNAPI,还可以开启GPU加速(需要在Interpreter配置里设置options.setGpuDelegateEnabled(true)),不过要注意部分设备兼容性。
  • 内存管理:Interpreter对象比较占内存,建议用单例模式创建,不用的时候一定要调用close()释放资源。

内容的提问来源于stack exchange,提问作者Shubham Shekhar

火山引擎 最新活动