You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在Android中实现预训练的TensorFlow TextSum模型?

确实,TensorFlow TextSum在Android端的实现资料比图像分类类的少太多了,我来一步步给你拆解具体的实现思路和步骤,帮你把这个模型跑起来:

一、先搞定TextSum模型的转换与优化

这一步是基础,毕竟Android端主要用TensorFlow Lite(TFLite)来跑模型,得先把训练好的TextSum模型转成TFLite支持的格式:

  • 导出SavedModel格式
    首先要把你训练好的TextSum Seq2Seq模型导出成SavedModel格式。训练时要明确输入输出的张量签名,比如输入命名为input_text_ids,输出命名为summary_text_ids,这样后续转换和Android端调用时更清晰。用Python代码导出的示例:

    import tensorflow as tf
    
    # 假设你的模型已经训练好,加载模型
    model = tf.keras.models.load_model('your_trained_textsum_model.h5')
    # 定义签名函数,明确输入输出
    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32, name='input_text_ids')])
    def serve_fn(input_ids):
        return {'summary_text_ids': model(input_ids)}
    # 保存为SavedModel
    tf.saved_model.save(model, 'textsum_saved_model', signatures={'serving_default': serve_fn})
    
  • 转换为TFLite模型
    用TFLite转换器把SavedModel转成.tflite文件,这里要注意TextSum是序列模型,可能涉及变长输入或自定义OP,所以转换时要开启允许自定义操作,还可以做量化优化来减小体积、提升速度:

    converter = tf.lite.TFLiteConverter.from_saved_model('textsum_saved_model')
    # 允许自定义OP(如果模型用了TFLite不原生支持的操作)
    converter.allow_custom_ops = True
    # 可选:开启动态范围量化,平衡精度和速度
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # 转换模型
    tflite_model = converter.convert()
    # 保存.tflite文件
    with open('textsum.tflite', 'wb') as f:
        f.write(tflite_model)
    
二、Android端的集成步骤

搞定模型后,就可以在Android项目里集成了:

  • 添加依赖
    在项目的app/build.gradle里添加TFLite相关依赖:

    dependencies {
        // 核心TFLite库
        implementation 'org.tensorflow:tensorflow-lite:2.15.0'
        // 可选:TFLite支持库,简化张量处理
        implementation 'org.tensorflow:tensorflow-lite-support:0.4.4'
    }
    
  • 放置模型与词汇表
    把转换好的textsum.tflite和训练时用的vocab.txt(词汇表文件)放到src/main/assets目录下,同时确保app/build.gradle里配置了assets打包:

    android {
        sourceSets {
            main {
                assets.srcDirs = ['src/main/assets']
            }
        }
    }
    
  • 文本预处理
    TextSum模型需要把输入文本转换成整数序列,这一步必须和训练时的预处理逻辑完全一致:

    // 从assets加载词汇表,构建词到id的映射
    private Map<String, Integer> loadVocab() throws IOException {
        Map<String, Integer> vocabMap = new HashMap<>();
        BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("vocab.txt")));
        String line;
        int index = 0;
        while ((line = reader.readLine()) != null) {
            vocabMap.put(line.trim(), index++);
        }
        reader.close();
        return vocabMap;
    }
    
    // 把输入文本转成模型需要的整数序列,padding到固定长度
    private int[] preprocessText(String inputText, Map<String, Integer> vocab, int maxInputLen) {
        // 这里的分词逻辑要和训练时一致!比如训练时用了jieba分词,这里也要用相同的工具
        String[] tokens = inputText.split(" ");
        int[] inputIds = new int[maxInputLen];
        Arrays.fill(inputIds, 0); // 0是padding的id
        for (int i = 0; i < Math.min(tokens.length, maxInputLen); i++) {
            // 1是未知词的id,训练时定义的
            inputIds[i] = vocab.getOrDefault(tokens[i], 1);
        }
        return inputIds;
    }
    

    划重点:分词逻辑必须和训练对齐!比如训练时用了中文分词,Android端就得用对应的分词库(比如jieba的Android版本),不然模型输出会完全混乱。

  • 模型推理与后处理
    加载TFLite模型,执行推理,再把输出的整数序列转回文本:

    private TensorFlowLite.Interpreter tfliteInterpreter;
    
    // 初始化模型
    private void initTextSumModel() throws IOException {
        MappedByteBuffer modelBuffer = FileUtil.loadMappedFile(this, "textsum.tflite");
        // 配置多线程推理,提升速度
        TensorFlowLite.Interpreter.Options options = new TensorFlowLite.Interpreter.Options();
        options.setNumThreads(4);
        tfliteInterpreter = new TensorFlowLite.Interpreter(modelBuffer, options);
    }
    
    // 执行推理,生成摘要
    private String generateSummary(String inputText) throws IOException {
        Map<String, Integer> vocab = loadVocab();
        int maxInputLen = 512; // 和训练时的输入长度一致
        int maxSummaryLen = 128; // 训练时定义的最大摘要长度
        int vocabSize = vocab.size();
    
        // 预处理输入文本
        int[] inputIds = preprocessText(inputText, vocab, maxInputLen);
        // 准备输入张量(形状要和模型输入一致,比如[1, 512])
        int[][] inputTensor = new int[1][maxInputLen];
        inputTensor[0] = inputIds;
    
        // 准备输出张量(形状和模型输出一致,比如[1, 128, vocabSize])
        float[][][] outputTensor = new float[1][maxSummaryLen][vocabSize];
        tfliteInterpreter.run(inputTensor, outputTensor);
    
        // 后处理:把输出张量转成文本
        StringBuilder summaryBuilder = new StringBuilder();
        for (int i = 0; i < maxSummaryLen; i++) {
            int tokenId = argmax(outputTensor[0][i]);
            if (tokenId == 2) { // 2是结束符的id,训练时定义的
                break;
            }
            String token = getTokenById(vocab, tokenId);
            if (token != null) {
                summaryBuilder.append(token).append(" ");
            }
        }
        return summaryBuilder.toString().trim();
    }
    
    // 辅助函数:获取数组最大值的索引(对应概率最高的词id)
    private int argmax(float[] array) {
        int maxIndex = 0;
        float maxValue = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxValue) {
                maxValue = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }
    
    // 辅助函数:根据id获取对应的词
    private String getTokenById(Map<String, Integer> vocab, int id) {
        for (Map.Entry<String, Integer> entry : vocab.entrySet()) {
            if (entry.getValue() == id) {
                return entry.getKey();
            }
        }
        return null;
    }
    
  • 资源释放
    记得在页面销毁时释放TFLite解释器的资源,避免内存泄漏:

    @Override
    protected void onDestroy() {
        super.onDestroy();
        if (tfliteInterpreter != null) {
            tfliteInterpreter.close();
        }
    }
    
三、常见坑与解决方法
  • 变长输入不支持:如果模型用了变长输入,转换TFLite时要确保开启allow_custom_ops,或者在导出SavedModel时明确动态形状的签名。
  • 量化导致精度下降:如果用了整数量化后摘要质量变差,可以尝试动态范围量化或浮点量化,平衡精度和性能。
  • 自定义OP报错:如果模型用了TFLite不支持的自定义操作,要么替换成原生支持的OP,要么在Android端注册对应的自定义OP实现。

内容的提问来源于stack exchange,提问作者dhrumit patel

火山引擎 最新活动