You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何修改ClassifierActivity.java以接入自定义灰度图像神经网络

我最近刚把自己训练的灰度图像模型接入到这个TensorFlow Android Demo里,踩了几个坑,给你整理一下具体要改的地方:

1. 调整图像预处理逻辑:把RGB转成单通道灰度

Demo默认是按RGB三通道给模型喂数据,我们需要改成单通道灰度。找到ClassifierActivity里处理Bitmap像素的代码段(一般在和输入ByteBuffer填充相关的部分),把原来的RGB提取逻辑替换成灰度计算:

原来的代码大概是这样(提取RGB三通道):

int pixel = bitmap.getPixel(x, y);
inputBuffer.put((byte) Color.red(pixel));
inputBuffer.put((byte) Color.green(pixel));
inputBuffer.put((byte) Color.blue(pixel));

改成灰度计算(用标准的灰度转换公式):

int pixel = bitmap.getPixel(x, y);
// 计算灰度值:Y = 0.299*R + 0.587*G + 0.114*B
int grayValue = (int) (0.299 * Color.red(pixel) + 0.587 * Color.green(pixel) + 0.114 * Color.blue(pixel));
// 这里要和你训练模型时的输入格式匹配:如果是[0,255]的byte输入,直接put;如果是归一化的float输入,要做转换
inputBuffer.put((byte) grayValue);
2. 修改输入张量的通道数和缓冲区容量

因为灰度图是单通道,要把原来的三通道配置改成单通道:

  • 找到定义输入尺寸的常量(比如在TensorFlowImageClassifier类或者ClassifierActivity里),把INPUT_CHANNELS从3改成1
  • 同步调整输入ByteBuffer的容量计算:原来的是1 * INPUT_SIZE * INPUT_SIZE * 3 * BYTES_PER_CHANNEL,现在改成1 * INPUT_SIZE * INPUT_SIZE * 1 * BYTES_PER_CHANNEL

举个代码例子:

// 修改前
private static final int INPUT_CHANNELS = 3;
// 修改后
private static final int INPUT_CHANNELS = 1;

// 对应的ByteBuffer初始化也要改
inputBuffer = ByteBuffer.allocateDirect(
    1 * INPUT_SIZE * INPUT_SIZE * INPUT_CHANNELS * BYTES_PER_CHANNEL);
3. (可选)直接用Camera的YUV数据提升效率

如果Camera输出的是NV21格式的YUV图像,其实Y分量本身就是灰度,直接提取Y通道可以跳过RGB转灰度的步骤,节省计算资源:

// 假设拿到的Camera帧数据是nv21Bytes,宽高为previewWidth、previewHeight
byte[] grayBytes = new byte[previewWidth * previewHeight];
// NV21格式里Y通道是数据的前previewWidth*previewHeight字节
System.arraycopy(nv21Bytes, 0, grayBytes, 0, previewWidth * previewHeight);
// 接下来把grayBytes缩放到模型需要的INPUT_SIZE尺寸,再填充到输入缓冲区
4. 确保预处理和训练时完全一致

这一步很关键,不然模型输出会乱:

  • 如果训练时你把灰度值归一化到了[0,1]或者[-1,1],那代码里也要做同样的转换,比如:
    float normalizedGray = grayValue / 255.0f; // 归一化到[0,1]
    inputBuffer.putFloat(normalizedGray);
    
    这时候BYTES_PER_CHANNEL要改成4(因为float是4字节)
  • 检查图像是否需要翻转(Camera预览一般是镜像的,如果你训练时用的是正方向图像,要做镜像翻转)
  • 确认模型的输入节点名称和代码里指定的一致(比如原来的Demo用"input",你的模型如果是"input_1",要在代码里修改)

内容的提问来源于stack exchange,提问作者s11

火山引擎 最新活动