如何在Flax NNX训练流程中用Optax调度器控制模型调用的超参数并收集其值?
如何在Flax NNX训练流程中用Optax调度器控制模型调用的超参数并收集其值?
我完全理解你的需求:要让Optax调度器生成的动态超参数传入模型的前向传播逻辑,同时还要把每一步的超参数值完整收集下来。咱们直接从核心修改点入手,先给你完整的可运行代码,再逐部分解释关键调整:
完整可运行代码
from jax import numpy as jnp from jax import random from flax import nnx import optax from matplotlib import pyplot as plt if __name__ == '__main__': shape = (2, 55, 1) epochs = 123 rngs = nnx.Rngs(123) # 1. 自定义支持超参数输入的模型 class HyperLinear(nnx.Module): def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs): self.linear = nnx.Linear(in_features, out_features, rngs=rngs) def __call__(self, hyperparam: jnp.ndarray, inputs: jnp.ndarray) -> jnp.ndarray: # 可自定义超参数的作用逻辑,这里示例为缩放线性层输出 base_pred = self.linear(inputs) return hyperparam * base_pred # 初始化模型 model = HyperLinear(1, 1, rngs=rngs) skey = rngs.params() xx = random.uniform(skey, shape, minval=-10, maxval=10) # 生成观测数据 def f(x, m=2.234, b=-1.123): return m * x + b obs1, obs2 = f(xx), f(xx) x1, x2 = xx, xx c = 0.9 # 学习率调度器 learning_rate_schedule = optax.schedules.cosine_decay_schedule( init_value=2e-1, decay_steps=int(c * epochs), alpha=0.01, ) # 超参数调度器:从12线性增长到234 hyperparam_schedule = optax.schedules.linear_schedule( init_value=12, end_value=234, transition_steps=int(c * epochs), ) # 获取模型可训练参数 params = nnx.filter(model, nnx.Param) # 初始化优化器 optimizer = nnx.Optimizer( model, tx=optax.adam(learning_rate_schedule), wrt=params ) # 2. 定义带超参数的扫描训练循环 @nnx.scan( in_axes=(nnx.Carry, None, None, 0), # step按序列传入每一步 out_axes=(nnx.Carry, 0, 0, 0), # 输出loss、mae、超参数值 length=epochs ) def optimizer_scan(carry, x, obs, step): model, optimizer = carry # 计算当前step对应的超参数值 current_hyperparam = hyperparam_schedule(step) # 定义包含超参数的损失函数 def loss_function(model, inputs, obs, hyperparam): prediction = model(hyperparam, inputs) # 超参数传入模型 error = obs - prediction loss = jnp.mean(error ** 2) mae = jnp.mean(jnp.abs(error)) # 返回损失+辅助信息(mae+当前超参数) return loss, (mae, current_hyperparam) # 计算损失和梯度 (loss, (mae, current_hyperparam)), grads = nnx.value_and_grad( loss_function, has_aux=True )(model, x, obs, current_hyperparam) # 更新模型参数 optimizer.update(model, grads) return (model, optimizer), (loss, mae, current_hyperparam) # 生成训练步数序列,用于超参数计算 steps = jnp.arange(epochs) # 3. 运行训练循环,收集所有结果 (model, optimizer), (losses, maes, hyperparams) = optimizer_scan( (model, optimizer), x1, obs1, steps ) # 打印训练结果 print('AFTER TRAINING') print(f'最终训练MSE损失: {losses[-1]:.4f}') print(f'初始超参数值: {hyperparams[0]:.2f}') print(f'最终超参数值: {hyperparams[-1]:.2f}') # 测试集验证 test_pred = model(hyperparam_schedule(epochs-1), x2) test_error = obs2 - test_pred test_loss = jnp.mean(test_error ** 2) print(f'测试集MSE损失: {test_loss:.4f}') print(f'模型权重近似值(m): {model.linear.kernel.value.item():.4f}') print(f'模型偏置近似值(b): {model.linear.bias.value.item():.4f}') # 可选:绘制超参数与损失变化曲线 plt.figure(figsize=(12,4)) plt.subplot(1,2,1) plt.plot(hyperparams) plt.title('超参数随训练步数变化') plt.xlabel('训练轮次') plt.ylabel('超参数值') plt.subplot(1,2,2) plt.plot(losses, label='MSE损失') plt.plot(maes, label='MAE') plt.title('损失随训练步数变化') plt.xlabel('训练轮次') plt.ylabel('损失值') plt.legend() plt.tight_layout() plt.show()
关键修改点详解
1. 自定义支持超参数的模型
原来的nnx.Linear不支持额外输入超参数,我们自定义HyperLinear模块,让它的__call__方法接收超参数并参与前向计算。你可以根据自己的需求修改超参数的作用逻辑(比如缩放权重、调整偏置等),示例中是用超参数缩放线性层的输出。
2. 在扫描循环中动态获取超参数
Optax调度器本质是一个输入step返回对应值的函数,我们通过jnp.arange(epochs)生成训练步数序列,把它作为输入传入nnx.scan循环,每一步都能计算出当前的超参数值。
3. 调整扫描循环的输入输出,收集超参数
- 修改
nnx.scan的in_axes,让步数序列按每一步传入循环; - 修改
out_axes,让超参数值和loss、mae一起被收集为数组; - 在损失函数中把超参数作为辅助信息返回,确保梯度计算时不会影响参数更新(超参数是调度生成的,不需要求导)。
4. 修正优化器参数追踪
用nnx.filter(model, nnx.Param)准确获取模型的可训练参数,避免原代码中params = nnx.Param的模糊定义,确保优化器只更新模型的可训练参数,不会影响超参数。




