You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何加载多份npz格式训练与验证数据到Keras模型并完成训练验证?

我看了你的自定义生成器代码,发现几个关键问题导致它无法正常工作,咱们一步步来修复:

问题分析

  1. 没有正确构建批次数据:你的for循环里每次读取一个文件就直接覆盖XY,最终只返回最后一个文件的数据,完全没有把batch_size个文件的数据拼接成完整批次。
  2. 循环逻辑有漏洞:当i*batch_size >= len(file_list)时你重置了索引并打乱列表,但这时候没有处理当前批次,会直接跳过一轮数据生成。
  3. 缺少批次维度扩展:单个文件的X/Y形状是(1024,28),但Keras模型需要的批次输入应该带有批次维度,也就是(batch_size, 1024, 28)

修正后的生成器代码

import numpy as np

def tf_train_generator(file_list, batch_size=1):
    i = 0
    # 初始先打乱一次文件列表,保证训练起始的随机性
    np.random.shuffle(file_list)
    
    while True:
        start = i * batch_size
        end = start + batch_size
        
        # 当批次超出文件列表长度时,重置索引并重新打乱
        if end > len(file_list):
            i = 0
            np.random.shuffle(file_list)
            start = 0
            end = batch_size
        
        file_chunk = file_list[start:end]
        
        # 初始化批次数据容器
        batch_X = []
        batch_Y = []
        
        # 遍历批次内的文件,收集数据
        for file_path in file_chunk:
            temp = np.load(file_path)
            batch_X.append(temp['data1'])
            batch_Y.append(temp['data2'])
        
        # 转换为numpy数组,自动添加批次维度
        batch_X = np.array(batch_X)  # 最终形状: (batch_size, 1024, 28)
        batch_Y = np.array(batch_Y)
        
        i += 1
        yield batch_X, batch_Y

代码改进说明

  • 用列表batch_Xbatch_Y收集批次内所有文件的数据,最后统一转为numpy数组,保证输出是完整的批次数据。
  • 修正了循环逻辑,当批次超出文件范围时直接重置索引并重新取批次,不会跳过数据。
  • 初始时先打乱文件列表,避免训练起始阶段数据顺序固定。

生成器使用示例

import os

# 获取训练/验证集文件路径列表
train_files = [os.path.join('train_data', f) for f in os.listdir('train_data') if f.endswith('.npz')]
valid_files = [os.path.join('valid_data', f) for f in os.listdir('valid_data') if f.endswith('.npz')]

# 创建生成器
train_generator = tf_train_generator(train_files, batch_size=8)
valid_generator = tf_train_generator(valid_files, batch_size=8)

# 假设你已经定义好Keras模型,训练时这样调用
# model.fit(
#     train_generator,
#     steps_per_epoch=len(train_files) // 8,  # 每个epoch的步数
#     validation_data=valid_generator,
#     validation_steps=len(valid_files) // 8,
#     epochs=10
# )

如果你的文件总数无法被batch_size整除,最后一个批次的大小会小于设定值,Keras默认会正常处理这种情况;如果需要严格固定批次大小,可以在重置索引时额外处理剩余文件。

内容的提问来源于stack exchange,提问作者user13823202

火山引擎 最新活动