You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何无需转NumPy数组,用TensorFlow加载含图像掩码的分割数据集?

直接加载图像-掩码对作为TensorFlow分割数据集的简便方案

我完全懂你的痛点!tf.keras.preprocessing.image_dataset_from_directory确实只适合单目录带类别标签的场景,没法直接关联两个目录的图像作为输入-掩码对。不过用TensorFlow原生的tf.data API就能完美解决,全程不用碰NumPy数组,还能保持流水线的高效性。

核心思路

假设你的数据结构是这样的(如果不是,后面有适配方案):

  • train_images/:存放所有训练图像(比如img_001.jpgimg_002.png
  • train_masks/:存放对应掩码图像,文件名和训练图像完全一致img_001.jpgimg_002.png

我们的目标是把这两个目录里的文件一一配对,直接构建成(图像, 掩码)格式的tf.data.Dataset

步骤1:获取并匹配文件路径

首先用TensorFlow的文件操作工具获取所有文件路径,然后按文件名排序确保配对正确:

import tensorflow as tf
from pathlib import Path

# 定义数据目录路径
train_img_dir = Path("train_images")
train_mask_dir = Path("train_masks")

# 获取所有图像和掩码的路径(支持jpg/png格式)
img_paths = list(train_img_dir.glob("*.jpg")) + list(train_img_dir.glob("*.png"))
mask_paths = list(train_mask_dir.glob("*.jpg")) + list(train_mask_dir.glob("*.png"))

# 按文件名排序,保证图像和掩码一一对应
img_paths.sort()
mask_paths.sort()

如果你的掩码文件名和图像有固定后缀差异(比如图像是img_001.jpg,掩码是img_001_mask.jpg),可以直接通过路径替换生成掩码路径,不用手动排序:

# 用路径替换生成掩码路径,替代glob和排序
mask_paths = [train_mask_dir / f"{path.stem}_mask{path.suffix}" for path in img_paths]

步骤2:定义加载预处理函数

写一个函数来加载单对图像和掩码,处理解码、类型转换等操作(按需调整,比如掩码是否需要归一化):

def load_pair(img_path, mask_path):
    # 加载并预处理图像:解码+归一化到[0,1]
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)  # 用decode_png如果是PNG格式
    img = tf.image.convert_image_dtype(img, tf.float32)
    
    # 加载并预处理掩码:单通道+保持整数类型(因为掩码是类别标签)
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_jpeg(mask, channels=1)
    mask = tf.image.convert_image_dtype(mask, tf.uint8)  # 避免归一化,保留类别ID
    
    return img, mask

步骤3:构建高效的数据集流水线

tf.data把路径配对,映射加载函数,再加上打乱、批量、预取等优化:

# 从路径切片创建数据集,然后配对
img_ds = tf.data.Dataset.from_tensor_slices([str(p) for p in img_paths])
mask_ds = tf.data.Dataset.from_tensor_slices([str(p) for p in mask_paths])
train_ds = tf.data.Dataset.zip((img_ds, mask_ds))

# 映射加载函数,开启并行处理提升速度
train_ds = train_ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)

# 设置批量大小,打乱数据,预取数据避免训练时等待
batch_size = 8
train_ds = train_ds.shuffle(buffer_size=len(img_paths))
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

额外优化提示

  • 如果需要调整图像大小,可以在load_pair函数里加上tf.image.resize
  • 验证集可以用完全相同的逻辑构建
  • 这个方案全程在TensorFlow图模式下运行,支持分布式训练,性能比转NumPy数组再加载好很多

内容的提问来源于stack exchange,提问作者Rafael Sanjuan

火山引擎 最新活动