You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

Temporal Fusion Transformer是否学习时序全局趋势及训练验证问题

使用Temporal Fusion Transformer(TFT)做实时工艺预测的技术疑问

我正在用Temporal Fusion Transformer(TFT)训练时序数据,目标是给某工艺单元的运行过程提供任意时间点的实时预测。但训练和验证用的都是已完成工艺批次的完整历史时序数据(每个样本覆盖工艺全周期),基于这个设置,我有三个技术疑问:

  • 模型是否能学习到每个时序组的整体/全局趋势?也就是在实时运行时输入中间阶段的数据做预测,尽管模型是用全范围样本训练的,能不能依然表现良好?
  • 训练过程中,验证是只针对每个时序的最后解码器片段,还是覆盖整个时序范围?
  • 验证数据集应该包含未见过的组(即和训练数据不同的组ID),还是可以复用训练数据里的组ID?

简化的数据集配置代码

max_prediction_length = 60
max_encoder_length = 300
training_cutoff = df["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    df_train,
    time_idx="time_idx",
    target="", 
    group_ids=[""],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=[],
    static_reals=[], 
    time_varying_known_categoricals=[],
    variable_groups={},  
    time_varying_known_reals=[], 
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[], 
    allow_missing_timesteps=True,
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

validation = TimeSeriesDataSet.from_dataset(
    training, df_valid, predict=False, stop_randomization=True
)

# create dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size * 10, num_workers=0
)

问题解答

1. 模型对全局趋势的学习能力与中间阶段预测表现

TFT的门控机制和注意力模块天生擅长捕捉时序数据的全局趋势和局部阶段模式。用全周期批次数据训练时,模型会学习到不同批次工艺各阶段(启动、稳定、收尾)的典型特征。

实时输入中间阶段数据预测时,只要训练数据覆盖了足够多的典型中间阶段场景,模型通常能稳定输出结果。你设置的min_encoder_lengthmax_encoder_length的一半,训练时模型会接触到150-300步的不同长度编码器输入,这种设置能帮助模型适应实时场景中不同阶段的输入长度,对中间阶段预测是有利的。但如果工艺中间阶段存在训练样本中未覆盖的异常模式,模型预测可能会出现偏差。

2. 验证集的采样范围

从你的代码来看,验证集通过TimeSeriesDataSet.from_dataset(..., predict=False, stop_randomization=True)创建,此时验证集的采样逻辑和训练集类似,并非只针对每个时序的最后解码器片段

  • 训练时train=True的dataloader会随机采样150-300步的编码器片段,对应1-60步的解码器片段;
  • 验证时stop_randomization=True会固定采样方式,但依然会覆盖时序的不同位置,而非仅取最后60步的解码器片段。

如果需要验证仅针对每个时序的最后一段,你需要手动设置验证集的training_cutoff,或者在创建验证集时限制采样范围为每个批次的末尾部分。

3. 验证数据集的组ID选择

这取决于你的实时预测场景需求:

  • 若针对全新批次预测:验证集必须包含未见过的组ID,这样验证结果才能真实反映模型在部署时的泛化能力,毕竟不同工艺批次可能存在批次间差异;
  • 若针对已存在批次的中间阶段预测:可以复用训练数据中的组ID,但要注意验证采样的片段不能和训练集重叠(比如训练用批次前N步,验证用后M步),避免数据泄露。

从工业实时预测的通用场景来看,更推荐用未见过的组ID做验证,能更准确地测试模型应对新批次的能力。


内容的提问来源于stack exchange,提问作者YoungJoo Park

火山引擎 最新活动