You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何将GCS多桶压缩CSV分批喂入TensorFlow多Worker训练模型?

解决GCS压缩CSV读取错误+多Worker分布式训练+数据转换方案

一、先搞定UnicodeDecodeError和压缩CSV读取问题

你遇到的UnicodeDecodeError多半是因为CSV文件不是UTF-8编码(比如GBK、Latin-1这类),或者make_csv_dataset没正确识别你的压缩格式。make_csv_dataset虽然用起来省心,但处理非标准编码/压缩的灵活性不如底层的tf.data.TextLineDataset。我推荐你用下面的方法来读取:

  1. 指定压缩格式与正确编码
    先确认你的压缩文件是gzip还是其他格式(比如bz2),然后用TextLineDataset配合压缩选项,同时指定匹配文件的编码:

    import tensorflow as tf
    
    # 把多个存储桶里的文件路径都列进来,或者用通配符自动匹配
    gcs_paths = ["gs://bucket1/file1.csv.gz", "gs://bucket2/file2.csv.gz"]
    
    # 读取压缩CSV,compression_type改成你实际的压缩格式,encoding根据文件编码调整
    dataset = tf.data.TextLineDataset(gcs_paths, compression_type="GZIP", encoding="latin-1")
    
    # 如果CSV有表头,记得跳过第一行
    dataset = dataset.skip(1)
    
  2. 自定义CSV解析逻辑
    tf.io.decode_csv手动解析每行数据,这样能灵活应对各种数据格式:

    # 假设你的CSV有3列:数值特征1、分类特征、标签(最后一列)
    # 给每列设置默认值,对应数据类型
    record_defaults = [tf.float32, tf.string, tf.float32]
    
    def parse_csv(line):
        # 解析每行,field_delim是CSV的分隔符,默认逗号,是其他的话改这里
        fields = tf.io.decode_csv(line, record_defaults=record_defaults, field_delim=",")
        # 拆分特征和标签
        features = tf.stack(fields[:-1])
        label = fields[-1]
        return features, label
    
    # 并行应用解析函数,提升效率
    dataset = dataset.map(parse_csv, num_parallel_calls=tf.data.AUTOTUNE)
    

二、添加数据转换预处理

在tf.data的pipeline里插入map操作就能完成各种数据转换,比如归一化、分类特征编码这些,直接看示例:

# 示例:归一化数值特征,编码分类特征
def preprocess(features, label):
    # 假设features[0]是需要归一化的数值特征,这里用提前算好的均值和方差
    normalized_feature = (features[0] - 10.0) / 5.0
    # 给分类特征(features[1])做字符串编码
    vocab = tf.constant(["cat", "dog", "bird"])
    table = tf.lookup.StaticVocabularyTable(
        tf.lookup.KeyValueTensorInitializer(vocab, tf.range(tf.size(vocab))),
        num_oov_buckets=1  # 处理不在词汇表里的未知类别
    )
    encoded_cat = table.lookup(features[1])
    # 拼接处理后的所有特征
    processed_features = tf.concat([[normalized_feature], [tf.cast(encoded_cat, tf.float32)]], axis=0)
    # 标签类型转换(如果需要的话)
    processed_label = tf.cast(label, tf.float32)
    return processed_features, processed_label

# 应用预处理,同样用并行调用提升速度
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

# 最后做shuffle、batch、prefetch优化训练性能
batch_size = 64
dataset = dataset.shuffle(10000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

三、多Worker分布式训练配置

tf.distribute.MultiWorkerMirroredStrategy就能实现多Worker分布式训练,步骤很清晰:

  1. 配置Worker集群
    每个Worker节点需要通过TF_CONFIG环境变量指定集群信息,比如:

    # Worker 0的环境变量设置
    export TF_CONFIG='{"cluster": {"worker": ["worker0.example.com:12345", "worker1.example.com:12345"]}, "task": {"type": "worker", "index": 0}}'
    
    # Worker 1的环境变量设置
    export TF_CONFIG='{"cluster": {"worker": ["worker0.example.com:12345", "worker1.example.com:12345"]}, "task": {"type": "worker", "index": 1}}'
    

    如果你在GCP的Vertex AI或者AI Platform上训练,平台会自动帮你配置好TF_CONFIG,不用手动操作。

  2. 在策略范围内构建模型
    所有和模型定义、编译相关的代码都要放在策略的作用域里:

    # 初始化分布式策略
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    
    with strategy.scope():
        # 定义你的Keras函数式模型,这里是个简单示例,你换成自己的模型就行
        input_layer = tf.keras.Input(shape=(2,))
        dense1 = tf.keras.layers.Dense(64, activation="relu")(input_layer)
        dense2 = tf.keras.layers.Dense(32, activation="relu")(dense1)
        output_layer = tf.keras.layers.Dense(1, activation="sigmoid")(dense2)
        model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
    
        # 编译模型,选择合适的优化器、损失函数和评估指标
        model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    
    # 启动训练,直接传入之前构建好的数据集就行
    model.fit(dataset, epochs=10)
    
  3. 数据集分发小提示
    多Worker模式下,tf.data会自动给每个Worker分配不同的数据分片,不用手动拆分数据。如果你的数据量特别大,也可以用strategy.distribute_datasets_from_function来更精细地控制数据分发,但一般默认的方式就足够用了。

额外小技巧

  • 不确定CSV编码的话,可以先在本地拿个样本文件用chardet检测:import chardet; print(chardet.detect(open("sample.csv", "rb").read())),得到编码后再在TensorFlow里指定。
  • 对于GCS上大量的文件,推荐用tf.data.Dataset.list_files自动匹配,比如tf.data.Dataset.list_files("gs://bucket*/data_*.csv.gz"),不用手动列所有文件路径。

内容的提问来源于stack exchange,提问作者Munisha Bansal

火山引擎 最新活动