You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何批量写入TFRecords?处理4000万行CSV训练数据

这个问题问到点子上了!处理4000万行的超大CSV,逐个写入TFRecords确实效率拉胯,批量处理才是应对这种场景的正确姿势。我给你分享两种实用的实现方式,还有一些优化技巧:

方法一:用tf.data.Dataset实现高效批量解析与写入

TensorFlow的tf.data模块天生适合处理大规模数据集,它支持并行解析和批量处理,内存占用也更可控。下面是具体实现:

import tensorflow as tf

def parse_csv_line(line):
    # 替换成你的CSV实际列的特征描述和默认值
    record_defaults = [0.0, 0, 0]  # 对应feature1(浮点)、feature2(整数)、label(整数)
    fields = tf.io.decode_csv(line, record_defaults=record_defaults)
    
    # 构建并序列化tf.train.Example
    example = tf.train.Example(features=tf.train.Features(feature={
        'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[fields[0]])),
        'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[fields[1]])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[fields[2]]))
    }))
    return example.SerializeToString()

# 读取CSV并跳过表头
dataset = tf.data.TextLineDataset('your_large_data.csv').skip(1)
# 并行解析每行数据(用AUTOTUNE自动适配CPU核心数)
dataset = dataset.map(parse_csv_line, num_parallel_calls=tf.data.AUTOTUNE)
# 设置批量大小,根据你的内存情况调整,比如1024或2048
batch_size = 1024
dataset = dataset.batch(batch_size)

# 批量写入TFRecord
with tf.io.TFRecordWriter('output.tfrecord') as writer:
    for batch in dataset:
        for serialized_example in batch:
            writer.write(serialized_example.numpy())

这种方式的优势在于:tf.data会自动处理数据的分片和加载,不会一次性把4000万行都读到内存里,而且并行解析能大幅提升处理速度。

方法二:手动攒批量(适合自定义解析逻辑)

如果你需要对CSV做一些TensorFlow解析函数不好处理的自定义操作(比如复杂的字符串处理),可以用Python原生的csv模块读取数据,手动攒够一批后再写入:

import csv
import tensorflow as tf

def create_tf_example(row):
    # 这里根据你的CSV列顺序,把字符串转成对应的数据类型
    feature1 = float(row[0])
    feature2 = int(row[1])
    label = int(row[2])
    
    # 序列化Example
    example = tf.train.Example(features=tf.train.Features(feature={
        'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[feature1])),
        'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[feature2])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    }))
    return example.SerializeToString()

batch_size = 1024
current_batch = []

with open('your_large_data.csv', 'r') as csv_file, tf.io.TFRecordWriter('output.tfrecord') as writer:
    csv_reader = csv.reader(csv_file)
    next(csv_reader)  # 跳过表头行
    for row in csv_reader:
        current_batch.append(create_tf_example(row))
        # 攒够一批就写入
        if len(current_batch) >= batch_size:
            for serialized_example in current_batch:
                writer.write(serialized_example)
            current_batch = []
    # 处理最后一批不足batch_size的数据
    if current_batch:
        for serialized_example in current_batch:
            writer.write(serialized_example)

这种方式更灵活,但要注意控制batch_size的大小,避免内存溢出。

额外优化技巧
  • 多文件分片:4000万行数据可以分成多个TFRecord文件(比如每个文件存100万行),后续训练时可以用tf.data.Dataset.list_files加载所有分片,实现并行读取,速度更快。
  • 启用压缩:写入时可以启用GZIP压缩,节省磁盘空间,比如:
    options = tf.io.TFRecordOptions(compression_type='GZIP')
    with tf.io.TFRecordWriter('output.tfrecord.gz', options=options) as writer:
        # 写入逻辑不变
    
  • 并行处理:如果机器有多个CPU核心,可以用多进程分别处理CSV的不同分片,写入不同的TFRecord文件,进一步提升处理效率。

内容的提问来源于stack exchange,提问作者Insectatorious

火山引擎 最新活动