You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何循环调用Keras fit函数?解决训练指标全为NaN/0的异常

解决多数据集循环训练的NaN/0指标异常问题

我来帮你拆解下问题核心,然后一步步修复这个循环训练的异常情况:

一、先定位代码里的致命问题

1. 数据清洗完全没生效

你写的data['open'].replace('nan',np.nan)data.dropna()都是无效操作——pandas的这些方法默认返回新的DataFrame,不会修改原数据。脏数据(比如字符串型的'nan')会直接进入后续流程,导致归一化时出现NaN,最终让模型训练崩溃。

2. 数据集初始化路径错误

你用x_train = [None] * len(os.listdir('/'))初始化列表,这里的根目录/会让列表长度远大于实际数据集数量,后续循环时会处理大量None值,训练自然出问题。应该换成你存放数据集的真实目录/directory

3. 宽泛的异常掩盖了所有错误

except: pass会跳过所有错误,比如文件读取失败、数据格式错误,你根本不知道哪些数据集处理出了问题,甚至可能把无效的None传入model.fit,直接导致指标异常。

4. 模型没有重置权重

循环训练时你用的是同一个模型实例,首次训练因为数据问题产生NaN后,模型权重已经变成NaN,后续训练只会在这个烂摊子上继续更新,所有指标自然全是NaN或0。


二、修正后的完整代码示例

import os
import numpy as np
import pandas as pd
# 假设你的模型定义在这里,以LSTM为例
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

def create_model(input_shape):
    # 每次训练新数据集前,创建全新模型重置权重
    model = Sequential()
    model.add(LSTM(50, input_shape=input_shape))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
    return model

def normalise_windows(window_data):
    # 给归一化加保护,避免除以0导致NaN
    normalised_data = []
    for window in window_data:
        diffs = window - window[0]
        std = window.std()
        if std == 0:
            # 窗口内数据全相同,用0填充避免NaN
            normalised_data.append(np.zeros_like(diffs))
        else:
            normalised_data.append(diffs / std)
    return np.array(normalised_data)

def format_data(data, seq_len, normalise_window):
    sequence_length = seq_len + 1
    # 先检查数据长度是否足够生成序列
    if len(data) <= sequence_length:
        return None, None, None, None
    result = []
    for index in range(len(data) - sequence_length):
        result.append(data[index: index + sequence_length])
    if normalise_window:
        result = normalise_windows(result)
    result = np.array(result)
    row = round(0.85 * result.shape[0])
    # 确保训练/测试集都有足够样本
    if row < 1 or (result.shape[0] - row) < 1:
        return None, None, None, None
    train = result[:int(row), :]
    np.random.shuffle(train)
    x_train = train[:, :-1]
    y_train = train[:, -1]
    x_test = result[int(row):, :-1]
    y_test = result[int(row):, -1]
    x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
    x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
    return [x_train, y_train, x_test, y_test]

# 初始化数据集列表,用目标目录的真实文件数
data_dir = '/directory'
file_list = os.listdir(data_dir)
x_train = [None] * len(file_list)
y_train = [None] * len(file_list)
x_test = [None] * len(file_list)
y_test = [None] * len(file_list)

seq_len = 20  # 替换成你的实际序列长度
batch = 32
epochs = 10

# 处理每个数据集
for i, filename in enumerate(file_list):
    print(f"开始处理文件: {filename}")
    try:
        data_path = os.path.join(data_dir, filename)
        data = pd.read_csv(data_path, index_col=0, header=0)
        # 修复数据清洗:赋值回原变量,只删除目标列的NaN行
        data['open'] = data['open'].replace('nan', np.nan)
        data = data.dropna(subset=['open'])
        if len(data['open']) == 0:
            print(f"文件 {filename} 无有效数据,跳过")
            continue
        new = data['open'].tolist()
        # 检查处理后的数据是否有效
        xt, yt, xte, yte = format_data(new, seq_len, True)
        if xt is None:
            print(f"文件 {filename} 数据量不足,跳过")
            continue
        x_train[i] = xt
        y_train[i] = yt
        x_test[i] = xte
        y_test[i] = yte
    except Exception as e:
        # 打印具体错误,方便排查
        print(f"处理文件 {filename} 出错: {str(e)}")
        continue

# 循环训练每个有效数据集
for i in range(len(x_train)):
    xt = x_train[i]
    yt = y_train[i]
    if xt is None or yt is None:
        print(f"跳过无效数据集索引 {i}")
        continue
    print(f"开始训练: {i+1}/{len(x_train)}")
    # 每次训练创建全新模型
    model = create_model((xt.shape[1], 1))
    try:
        # 数据集太小时关闭验证集,避免无样本的情况
        val_split = 0.05 if xt.shape[0] >= 20 else 0.0
        model.fit(xt, yt, batch_size=batch, epochs=epochs, validation_split=val_split)
    except Exception as e:
        print(f"训练数据集 {i} 出错: {str(e)}")
        continue

关键改进点说明

  1. 数据清洗生效:把replacedropna的结果赋值回原数据,精准删除目标列的NaN行,避免脏数据流入训练。
  2. 模型重置:用create_model函数每次训练前生成全新模型,确保每个数据集的训练都是从初始权重开始。
  3. 错误可视化:替换宽泛的except: pass,打印具体异常信息,快速定位问题文件。
  4. 数据有效性检查:在format_data里提前拦截长度不足的数据集,归一化时处理全相同数据的情况,防止除以0生成NaN。
  5. 验证集保护:数据集过小时自动关闭validation_split,避免出现无验证样本导致的指标异常。

内容的提问来源于stack exchange,提问作者Carson P

火山引擎 最新活动