You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

使用Google LiteRT加载TensorFlow Lite导出模型时的兼容性错误及解决方案咨询

Google LiteRT加载TensorFlow Lite导出模型时的兼容性错误及解决方案咨询

问题概述

你提到之前用TensorFlow训练的图像分类模型,导出为TFLite格式后在Java代码中运行正常,但迁移到Kotlin并使用Google LiteRT(官方称其为TFLite的重命名版本)加载时,出现了模型兼容性错误。下面先梳理你的相关代码和错误信息,再给出针对性的解决建议。


1. 模型训练的Python代码

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
import os

# 加载并预处理数据集
data = tf.keras.utils.image_dataset_from_directory('snails', image_size=(256,256), shuffle=True)
class_names = data.class_names
num_classes = len(class_names)
print("Classes:", class_names)
data = data.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
data = data.shuffle (5235) # 打乱所有数据
data = data.take(5235) # 使用全部数据
dataset_size = 5235 # 总数据量
train_size = int(3664) # 训练集大小(总数据*0.7向上取整)
val_size = int(524) # 验证集大小
test_size = 1047 # 测试集大小(总数据*0.2)
train = data.take(train_size)
val = data.skip(train_size).take(val_size)
test = data.skip(train_size + val_size).take(test_size)
AUTOTUNE = tf.data.AUTOTUNE
train = train.cache().prefetch(AUTOTUNE)
val = val.cache().prefetch(AUTOTUNE)
test = test.cache().prefetch(AUTOTUNE)

# 构建迁移学习模型
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
for layer in base_model.layers:
    layer.trainable = False
inputs = Input(shape=(256,256,3))
x = base_model(inputs)
x = GlobalAveragePooling2D()(x)
x = Dense(32, activation="relu", kernel_regularizer= l2(0.0005))(x)
x = Dense(64, activation="relu", kernel_regularizer= l2(0.0005))(x)
x = Dropout (0.3)(x)
predictions = Dense(num_classes, activation="softmax")(x)
model = Model(inputs=inputs, outputs=predictions)

# 编译并训练模型
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
logdir = 'logs'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
custom = model.fit(train, validation_data=val, epochs=2, callbacks=[tensorboard_callback])

# 微调模型
for layer in base_model.layers[-3:]:
    layer.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
finetune = model.fit(train, validation_data=val, epochs=4, initial_epoch=2, callbacks=[tensorboard_callback])

# 保存模型
model.save(os.path.join('models', 'snailVGG3.h5'))

2. Kotlin中LiteRT的依赖与分类器代码

Gradle依赖(Groovy格式)

litert = { module = "com.google.ai.edge.litert:litert", version.ref = "litert" }
litert-gpu = { module = "com.google.ai.edge.litert:litert-gpu", version.ref = "litertGpu" }
litert-metadata = { module = "com.google.ai.edge.litert:litert-metadata", version.ref = "litertMetadata" }
litert-support = { module = "com.google.ai.edge.litert:litert-support", version.ref = "litertSupport" }

Kotlin图像分类器代码

import android.content.Context
import android.graphics.Bitmap
import com.google.ai.edge.litert.Accelerator
import com.google.ai.edge.litert.CompiledModel
import com.google.ai.edge.litert.support.image.ImageProcessor
import com.google.ai.edge.litert.support.image.TensorImage
import com.google.ai.edge.litert.support.image.ops.NormalizeOp
import com.google.ai.edge.litert.support.image.ops.ResizeOp
import com.google.ai.edge.litert.support.tensorbuffer.TensorBuffer

data class Classification(val label: String, val confidence: Float)

class ImageClassifier(private val context: Context) {

    private var labels: List<String> = emptyList()
    private val modelInputWidth = 256
    private val modelInputHeight = 256
    private val threshold: Float= 0.9f
    private val maxResults: Int = 1

    private var imageProcessor = ImageProcessor.Builder()
        .add(ResizeOp(modelInputHeight,modelInputWidth, ResizeOp.ResizeMethod.BILINEAR))
        .add(NormalizeOp(0f,255f))
        .build()

    private var model: CompiledModel = CompiledModel.create(
        context.assets,
        "snailVGG2.tflite",
        CompiledModel.Options(Accelerator.CPU))
    init {
        labels = context.assets.open("snail_types.txt").bufferedReader().readLines()
    }

    fun classify(bitmap: Bitmap): List<Classification> {

        if (bitmap.width <= 0 || bitmap.height <= 0) return emptyList()

        val inputBuffer = model.createInputBuffers()
        val outputBuffer = model.createOutputBuffers()

        val tensorImage = TensorImage(TensorBuffer.DataType.FLOAT32).apply { load(bitmap) }

        val processedImage = imageProcessor.process(tensorImage)
        processedImage.buffer.rewind()

        val floatBuffer = processedImage.buffer.asFloatBuffer()
        val inputArray = FloatArray(1*256*256*3)
        floatBuffer.get(inputArray)

        inputBuffer[0].writeFloat(inputArray)

        model.run(inputBuffer, outputBuffer)

        val outputFloatArray = outputBuffer[0].readFloat()

        inputBuffer.forEach{it.close()}
        outputBuffer.forEach{it.close()}

        return outputFloatArray
            .mapIndexed {index, confidence -> Classification(labels[index], confidence) }
            .filter { it.confidence >= threshold }
            .sortedByDescending { it.confidence }
            .take(maxResults)
    }
}

3. 遇到的错误日志

[third_party/odml/litert/litert/runtime/tensor_buffer.cc:103] Failed to get num packed bytes
2025-12-18 04:15:19.894 25692-25692 tflite                  com.example.kuholifier_app           E  [third_party/odml/litert/litert/kotlin/src/main/jni/litert_compiled_model_jni.cc:538] Failed to create input buffers: ERROR: [third_party/odml/litert/litert/cc/litert_compiled_model.cc:123]
                                                                                                    └ ERROR: [third_party/odml/litert/litert/cc/litert_compiled_model.cc:82]
                                                                                                    └ ERROR: [third_party/odml/litert/litert/cc/litert_tensor_buffer.cc:49]

针对性解决建议

1. 优先切换回TensorFlow Lite依赖(最稳妥)

既然你之前的Java代码用TFLite能正常运行,迁移到Kotlin时直接沿用TFLite的Kotlin依赖是最直接的方案,完全兼容你已有的模型:

  • 替换Gradle依赖为TFLite官方依赖:
    implementation 'org.tensorflow:tensorflow-lite:2.15.0' // 选择和训练模型时匹配的版本
    implementation 'org.tensorflow:tensorflow-lite-support:0.4.4'
    
  • 调整Kotlin代码适配TFLite API,比如用Interpreter加载模型,这部分和你之前的Java代码逻辑基本一致,迁移成本很低。

2. 若坚持使用LiteRT,尝试以下修复步骤

  • 确认模型文件一致性:你Python代码保存的是snailVGG3.h5,但Kotlin代码里加载的是snailVGG2.tflite,要确保你导出的TFLite模型文件名正确,且确实是从训练好的snailVGG3.h5转换而来的。重新导出模型的代码示例:
    import tensorflow as tf
    
    model = tf.keras.models.load_model('models/snailVGG3.h5')
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # 若需要量化可添加相关配置
    tflite_model = converter.convert()
    with open('snailVGG3.tflite', 'wb') as f:
        f.write(tflite_model)
    
  • 匹配LiteRT与TensorFlow版本:LiteRT的版本需要和你导出模型时使用的TensorFlow版本严格对应,比如如果LiteRT用的是2.16.0版本,导出模型时也要用TensorFlow 2.16.0,避免版本不兼容导致的模型格式差异。
  • 调整LiteRT模型加载参数:创建CompiledModel时,尝试去掉加速器指定,使用默认选项:
    private var model: CompiledModel = CompiledModel.create(
        context.assets,
        "snailVGG3.tflite"
    )
    
    或者尝试使用Accelerator.DEFAULT代替Accelerator.CPU

3. 检查模型输入输出格式

LiteRT对模型的输入输出张量格式可能有更严格的要求,你可以用TensorFlow Lite的interpreter.getInputTensor(0).shape()在Python中检查模型的输入形状,确保和Kotlin代码中设置的(256,256,3)完全一致,同时确认图像预处理的归一化逻辑(你代码里是NormalizeOp(0f,255f),和Python里的x/255.0是一致的,这部分没问题)。

总体来说,切换回TensorFlow Lite依赖是最快解决问题的方案,因为你已经验证过模型在TFLite上的可用性;如果想尝试LiteRT,重点排查版本匹配和模型加载的细节问题。

火山引擎 最新活动