如何批量写入TFRecords?处理4000万行CSV训练数据
这个问题问到点子上了!处理4000万行的超大CSV,逐个写入TFRecords确实效率拉胯,批量处理才是应对这种场景的正确姿势。我给你分享两种实用的实现方式,还有一些优化技巧:
方法一:用tf.data.Dataset实现高效批量解析与写入
TensorFlow的tf.data模块天生适合处理大规模数据集,它支持并行解析和批量处理,内存占用也更可控。下面是具体实现:
import tensorflow as tf def parse_csv_line(line): # 替换成你的CSV实际列的特征描述和默认值 record_defaults = [0.0, 0, 0] # 对应feature1(浮点)、feature2(整数)、label(整数) fields = tf.io.decode_csv(line, record_defaults=record_defaults) # 构建并序列化tf.train.Example example = tf.train.Example(features=tf.train.Features(feature={ 'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[fields[0]])), 'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[fields[1]])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[fields[2]])) })) return example.SerializeToString() # 读取CSV并跳过表头 dataset = tf.data.TextLineDataset('your_large_data.csv').skip(1) # 并行解析每行数据(用AUTOTUNE自动适配CPU核心数) dataset = dataset.map(parse_csv_line, num_parallel_calls=tf.data.AUTOTUNE) # 设置批量大小,根据你的内存情况调整,比如1024或2048 batch_size = 1024 dataset = dataset.batch(batch_size) # 批量写入TFRecord with tf.io.TFRecordWriter('output.tfrecord') as writer: for batch in dataset: for serialized_example in batch: writer.write(serialized_example.numpy())
这种方式的优势在于:tf.data会自动处理数据的分片和加载,不会一次性把4000万行都读到内存里,而且并行解析能大幅提升处理速度。
方法二:手动攒批量(适合自定义解析逻辑)
如果你需要对CSV做一些TensorFlow解析函数不好处理的自定义操作(比如复杂的字符串处理),可以用Python原生的csv模块读取数据,手动攒够一批后再写入:
import csv import tensorflow as tf def create_tf_example(row): # 这里根据你的CSV列顺序,把字符串转成对应的数据类型 feature1 = float(row[0]) feature2 = int(row[1]) label = int(row[2]) # 序列化Example example = tf.train.Example(features=tf.train.Features(feature={ 'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[feature1])), 'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[feature2])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) })) return example.SerializeToString() batch_size = 1024 current_batch = [] with open('your_large_data.csv', 'r') as csv_file, tf.io.TFRecordWriter('output.tfrecord') as writer: csv_reader = csv.reader(csv_file) next(csv_reader) # 跳过表头行 for row in csv_reader: current_batch.append(create_tf_example(row)) # 攒够一批就写入 if len(current_batch) >= batch_size: for serialized_example in current_batch: writer.write(serialized_example) current_batch = [] # 处理最后一批不足batch_size的数据 if current_batch: for serialized_example in current_batch: writer.write(serialized_example)
这种方式更灵活,但要注意控制batch_size的大小,避免内存溢出。
额外优化技巧
- 多文件分片:4000万行数据可以分成多个TFRecord文件(比如每个文件存100万行),后续训练时可以用
tf.data.Dataset.list_files加载所有分片,实现并行读取,速度更快。 - 启用压缩:写入时可以启用GZIP压缩,节省磁盘空间,比如:
options = tf.io.TFRecordOptions(compression_type='GZIP') with tf.io.TFRecordWriter('output.tfrecord.gz', options=options) as writer: # 写入逻辑不变 - 并行处理:如果机器有多个CPU核心,可以用多进程分别处理CSV的不同分片,写入不同的TFRecord文件,进一步提升处理效率。
内容的提问来源于stack exchange,提问作者Insectatorious




