How do I customize the Dataset based on the official TensorFlow CIFAR-10 example?

Hey there! Let's walk through how to build your custom creat_dataset() function to replace the official CIFAR-10 dataset loading code in that ResNet example. Here's a practical, step-by-step guide that aligns with the original code's expectations:

1. First, Understand the Original Code's Requirements

The official code uses tf.data.FixedLengthRecordDataset to load CIFAR-10's binary files and parses each record into a tuple of (image_tensor, label_tensor):

  • Image: A 32x32x3 float32 tensor normalized to [0, 1]
  • Label: An int32 scalar representing the class index

Your custom creat_dataset() must return a tf.data.Dataset that produces this exact structure—otherwise, the downstream preprocessing and model training logic will break.

Pro tip: Keep the original code's is_training flag as a parameter in your function. It's used to toggle data augmentation and shuffling for training vs. evaluation.

2. Common Implementation Scenarios

Below are two typical ways to build your custom dataset, depending on your data source:

Case 1: Loading from Image Files (PNG/JPG)

If your data is stored as individual image files (e.g., organized into class-specific folders), use this approach:

import tensorflow as tf
import os

def creat_dataset(is_training=True, batch_size=32):
    # Replace with your actual data directory
    data_dir = "./your_custom_data"
    
    # 1. List all image files, shuffle if training
    file_pattern = os.path.join(data_dir, "train/*/*.png") if is_training else os.path.join(data_dir, "test/*/*.png")
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
    
    # 2. Define a function to parse images and labels from file paths
    def parse_image(file_path):
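        # Extract label from the parent folder name (e.g., "./data/train/class_2/img.png" → label=2)
        folder = tf.strings.split(file_path, os.sep)[-2]
        # Keep only the digits: to_number would fail on a name like "class_2"
        label = tf.strings.regex_replace(folder, "[^0-9]", "")
        label = tf.strings.to_number(label, out_type=tf.int32)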
        
        # Read and process the image
        image = tf.io.read_file(file_path)
        image = tf.image.decode_png(image, channels=3)  # Use decode_jpeg if your images are JPG
        image = tf.cast(image, tf.float32) / 255.0  # Normalize to [0, 1]
        image = tf.image.resize(image, (32, 32))  # Match CIFAR-10's 32x32 size
        
        return image, label
    
    # 3. Apply parsing with parallel processing
    dataset = dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    
    # 4. Add training-specific augmentation (reuse the original code's logic!)
    if is_training:
        # Copy the data_augmentation function from the official CIFAR-10 example
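        def data_augmentation(image):
            # Pad to 40x40 first; a 32x32 crop of a 32x32 image would be a no-op
            image = tf.image.resize_with_crop_or_pad(image, 40, 40)
            image = tf.image.random_crop(image, [32, 32, 3])
            image = tf.image.random_flip_left_right(image)
            return image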
        
        dataset = dataset.map(
            lambda image, label: (data_augmentation(image), label),
            num_parallel_calls=tf.data.AUTOTUNE
        )
    
    # 5. Batch and prefetch for performance
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset
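
The parsing step above assumes the class folder name contains the numeric label. If your folders carry plain names instead, one possible approach is to build a name-to-index mapping up front. The sketch below is only an illustration: build_class_index and the folder names "dog", "cat" and "bird" are hypothetical, and the train/<class>/ layout is assumed from the file pattern used above.

```python
import os
import tempfile

def build_class_index(split_dir):
    """Map each class subfolder name to an integer index, in sorted order."""
    class_names = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )
    return {name: index for index, name in enumerate(class_names)}

# Demo on a throwaway directory tree (hypothetical class names).
root = tempfile.mkdtemp()
for name in ("dog", "cat", "bird"):
    os.makedirs(os.path.join(root, "train", name))

class_index = build_class_index(os.path.join(root, "train"))
print(class_index)
```

Sorting keeps the mapping stable between runs, so the same folder always gets the same label. Inside parse_image, such a mapping could then be applied with a lookup table such as tf.lookup.StaticHashTable instead of tf.strings.to_number.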

Case 2: Loading from Custom Binary Files

If your data is stored in a binary format similar to CIFAR-10, adapt the original code's logic like this:

import tensorflow as tf

def creat_dataset(is_training=True, batch_size=32):
    # Replace with your binary file paths
    filenames = ["./your_data/train.bin"] if is_training else ["./your_data/test.bin"]
    # Define the byte size of each record (adjust based on your data structure)
    # Example: 1 byte label + 32*32*3 bytes pixel data = same as CIFAR-10
    _CUSTOM_RECORD_BYTES = 1 + 32*32*3
    
    # 1. Load binary records
    dataset = tf.data.FixedLengthRecordDataset(filenames, _CUSTOM_RECORD_BYTES)
    
    # 2. Parse each record into image and label
    def parse_record(raw_record):
        record = tf.io.decode_raw(raw_record, tf.uint8)
        # Extract label (first byte)
        label = tf.cast(record[0], tf.int32)
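        # Extract and reshape image (remaining bytes). CIFAR-10 stores pixels
        # channel-first (3x32x32), so reshape to that layout, then transpose to 32x32x3.
        # If your file is already height-width-channel, reshape to [32, 32, 3] directly.
        image = tf.reshape(record[1:], [3, 32, 32])
        image = tf.transpose(image, [1, 2, 0])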
        image = tf.cast(image, tf.float32) / 255.0
        
        return image, label
    
    # 3. Apply parsing and augmentation
    dataset = dataset.map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    
    if is_training:
        # Reuse the same data_augmentation function from Case 1
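        def data_augmentation(image):
            # Pad to 40x40 first; a 32x32 crop of a 32x32 image would be a no-op
            image = tf.image.resize_with_crop_or_pad(image, 40, 40)
            image = tf.image.random_crop(image, [32, 32, 3])
            image = tf.image.random_flip_left_right(image)
            return image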
        
        dataset = dataset.map(
            lambda image, label: (data_augmentation(image), label),
            num_parallel_calls=tf.data.AUTOTUNE
        )
    
    # 4. Finalize dataset
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset
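
To try the binary loader before converting real data, it can help to generate a small file in the assumed record layout (1 label byte followed by 32*32*3 pixel bytes). The sketch below uses only the standard library; write_fake_records, read_labels and the temporary path are hypothetical names for illustration, and the pixel bytes are random, so their channel order carries no meaning here.

```python
import os
import random
import tempfile

HEIGHT, WIDTH, CHANNELS = 32, 32, 3
RECORD_BYTES = 1 + HEIGHT * WIDTH * CHANNELS  # 1 label byte + 3072 pixel bytes

def write_fake_records(path, num_records, num_classes=10, seed=0):
    """Write num_records random fixed-length records; return the labels written."""
    rng = random.Random(seed)
    labels = []
    with open(path, "wb") as f:
        for _ in range(num_records):
            label = rng.randrange(num_classes)
            pixels = bytes(rng.randrange(256) for _ in range(RECORD_BYTES - 1))
            f.write(bytes([label]) + pixels)
            labels.append(label)
    return labels

def read_labels(path):
    """Read back only the first (label) byte of each record."""
    labels = []
    with open(path, "rb") as f:
        while True:
            record = f.read(RECORD_BYTES)
            if len(record) < RECORD_BYTES:
                break
            labels.append(record[0])
    return labels

demo_path = os.path.join(tempfile.mkdtemp(), "train.bin")
written = write_fake_records(demo_path, num_records=5)
```

The file size should be an exact multiple of the record size; if it is not, the record length passed to tf.data.FixedLengthRecordDataset does not match how the file was written.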

3. Critical Checks for Compatibility

  • Output Structure: Double-check that each element in your dataset is a tuple of (image, label) with the correct shape and type (32x32x3 float32 image, int32 label).
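  • Shuffling: If training, shuffle individual examples by adding dataset = dataset.shuffle(buffer_size=10000) after parsing and before batch(). Note that list_files(shuffle=True) in Case 1 only shuffles the file order, and Case 2 as written does not shuffle at all.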
  • Performance: Always use num_parallel_calls=tf.data.AUTOTUNE and prefetch to optimize data loading speed.
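
The checks above can be partly automated by inspecting the dataset's element_spec before handing it to the training loop. The following is a minimal sketch, assuming a batched dataset; check_cifar_like is a hypothetical helper name, and the all-zero tensors are stand-in data used only to build a demo pipeline.

```python
import tensorflow as tf

def check_cifar_like(dataset):
    """Assert that a batched dataset yields (float32 [N,32,32,3], int32 [N]) pairs."""
    spec = dataset.element_spec
    assert isinstance(spec, tuple) and len(spec) == 2, "expected (image, label) tuples"
    image_spec, label_spec = spec
    assert image_spec.dtype == tf.float32, image_spec.dtype
    assert image_spec.shape[1:].as_list() == [32, 32, 3], image_spec.shape
    assert label_spec.dtype == tf.int32, label_spec.dtype
    assert label_spec.shape.rank == 1, label_spec.shape

# Stand-in data in place of a real creat_dataset() result.
images = tf.zeros([8, 32, 32, 3], tf.float32)
labels = tf.zeros([8], tf.int32)
demo = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(buffer_size=8)      # shuffle individual examples before batching
    .batch(4)
    .prefetch(tf.data.AUTOTUNE)
)
check_cifar_like(demo)
```

Calling check_cifar_like(creat_dataset(is_training=True)) in the same way would surface a wrong dtype or shape early, rather than as an error deep inside the model code.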

The question comes from Stack Exchange; it was asked by Elio.