如何在Android中实现预训练的TensorFlow TextSum模型?
确实,TensorFlow TextSum在Android端的实现资料比图像分类类的少太多了,我来一步步给你拆解具体的实现思路和步骤,帮你把这个模型跑起来:
这一步是基础,毕竟Android端主要用TensorFlow Lite(TFLite)来跑模型,得先把训练好的TextSum模型转成TFLite支持的格式:
导出SavedModel格式
首先要把你训练好的TextSum Seq2Seq模型导出成SavedModel格式。训练时要明确输入输出的张量签名,比如输入命名为input_text_ids,输出命名为summary_text_ids,这样后续转换和Android端调用时更清晰。用Python代码导出的示例:import tensorflow as tf # 假设你的模型已经训练好,加载模型 model = tf.keras.models.load_model('your_trained_textsum_model.h5') # 定义签名函数,明确输入输出 @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32, name='input_text_ids')]) def serve_fn(input_ids): return {'summary_text_ids': model(input_ids)} # 保存为SavedModel tf.saved_model.save(model, 'textsum_saved_model', signatures={'serving_default': serve_fn})转换为TFLite模型
用TFLite转换器把SavedModel转成.tflite文件,这里要注意TextSum是序列模型,可能涉及变长输入或自定义OP,所以转换时要开启允许自定义操作,还可以做量化优化来减小体积、提升速度:converter = tf.lite.TFLiteConverter.from_saved_model('textsum_saved_model') # 允许自定义OP(如果模型用了TFLite不原生支持的操作) converter.allow_custom_ops = True # 可选:开启动态范围量化,平衡精度和速度 converter.optimizations = [tf.lite.Optimize.DEFAULT] # 转换模型 tflite_model = converter.convert() # 保存.tflite文件 with open('textsum.tflite', 'wb') as f: f.write(tflite_model)
搞定模型后,就可以在Android项目里集成了:
添加依赖
在项目的app/build.gradle里添加TFLite相关依赖:dependencies { // 核心TFLite库 implementation 'org.tensorflow:tensorflow-lite:2.15.0' // 可选:TFLite支持库,简化张量处理 implementation 'org.tensorflow:tensorflow-lite-support:0.4.4' }放置模型与词汇表
把转换好的textsum.tflite和训练时用的vocab.txt(词汇表文件)放到src/main/assets目录下,同时确保app/build.gradle里配置了assets打包:android { sourceSets { main { assets.srcDirs = ['src/main/assets'] } } }文本预处理
TextSum模型需要把输入文本转换成整数序列,这一步必须和训练时的预处理逻辑完全一致:// 从assets加载词汇表,构建词到id的映射 private Map<String, Integer> loadVocab() throws IOException { Map<String, Integer> vocabMap = new HashMap<>(); BufferedReader reader = new BufferedReader( new InputStreamReader(getAssets().open("vocab.txt"))); String line; int index = 0; while ((line = reader.readLine()) != null) { vocabMap.put(line.trim(), index++); } reader.close(); return vocabMap; } // 把输入文本转成模型需要的整数序列,padding到固定长度 private int[] preprocessText(String inputText, Map<String, Integer> vocab, int maxInputLen) { // 这里的分词逻辑要和训练时一致!比如训练时用了jieba分词,这里也要用相同的工具 String[] tokens = inputText.split(" "); int[] inputIds = new int[maxInputLen]; Arrays.fill(inputIds, 0); // 0是padding的id for (int i = 0; i < Math.min(tokens.length, maxInputLen); i++) { // 1是未知词的id,训练时定义的 inputIds[i] = vocab.getOrDefault(tokens[i], 1); } return inputIds; }划重点:分词逻辑必须和训练对齐!比如训练时用了中文分词,Android端就得用对应的分词库(比如jieba的Android版本),不然模型输出会完全混乱。
模型推理与后处理
加载TFLite模型,执行推理,再把输出的整数序列转回文本:private TensorFlowLite.Interpreter tfliteInterpreter; // 初始化模型 private void initTextSumModel() throws IOException { MappedByteBuffer modelBuffer = FileUtil.loadMappedFile(this, "textsum.tflite"); // 配置多线程推理,提升速度 TensorFlowLite.Interpreter.Options options = new TensorFlowLite.Interpreter.Options(); options.setNumThreads(4); tfliteInterpreter = new TensorFlowLite.Interpreter(modelBuffer, options); } // 执行推理,生成摘要 private String generateSummary(String inputText) throws IOException { Map<String, Integer> vocab = loadVocab(); int maxInputLen = 512; // 和训练时的输入长度一致 int maxSummaryLen = 128; // 训练时定义的最大摘要长度 int vocabSize = vocab.size(); // 预处理输入文本 int[] inputIds = preprocessText(inputText, vocab, maxInputLen); // 准备输入张量(形状要和模型输入一致,比如[1, 512]) int[][] inputTensor = new int[1][maxInputLen]; inputTensor[0] = inputIds; // 准备输出张量(形状和模型输出一致,比如[1, 128, vocabSize]) float[][][] outputTensor = new float[1][maxSummaryLen][vocabSize]; tfliteInterpreter.run(inputTensor, outputTensor); // 后处理:把输出张量转成文本 StringBuilder summaryBuilder = new StringBuilder(); for (int i = 0; i < maxSummaryLen; i++) { int tokenId = argmax(outputTensor[0][i]); if (tokenId == 2) { // 2是结束符的id,训练时定义的 break; } String token = getTokenById(vocab, tokenId); if (token != null) { summaryBuilder.append(token).append(" "); } } return summaryBuilder.toString().trim(); } // 辅助函数:获取数组最大值的索引(对应概率最高的词id) private int argmax(float[] array) { int maxIndex = 0; float maxValue = array[0]; for (int i = 1; i < array.length; i++) { if (array[i] > maxValue) { maxValue = array[i]; maxIndex = i; } } return maxIndex; } // 辅助函数:根据id获取对应的词 private String getTokenById(Map<String, Integer> vocab, int id) { for (Map.Entry<String, Integer> entry : vocab.entrySet()) { if (entry.getValue() == id) { return entry.getKey(); } } return null; }资源释放
记得在页面销毁时释放TFLite解释器的资源,避免内存泄漏:@Override protected void onDestroy() { super.onDestroy(); if (tfliteInterpreter != null) { tfliteInterpreter.close(); } }
- 变长输入不支持:如果模型用了变长输入,转换TFLite时要确保开启
allow_custom_ops,或者在导出SavedModel时明确动态形状的签名。 - 量化导致精度下降:如果用了整数量化后摘要质量变差,可以尝试动态范围量化或浮点量化,平衡精度和性能。
- 自定义OP报错:如果模型用了TFLite不支持的自定义操作,要么替换成原生支持的OP,要么在Android端注册对应的自定义OP实现。
内容的提问来源于stack exchange,提问作者dhrumit patel




