You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在Flax NNX训练流程中用Optax调度器控制模型调用的超参数并收集其值?

如何在Flax NNX训练流程中用Optax调度器控制模型调用的超参数并收集其值?

我完全理解你的需求:要让Optax调度器生成的动态超参数传入模型的前向传播逻辑,同时还要把每一步的超参数值完整收集下来。咱们直接从核心修改点入手,先给你完整的可运行代码,再逐部分解释关键调整:


完整可运行代码

from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

if __name__ == '__main__':
    shape = (2, 55, 1)
    epochs = 123
    rngs = nnx.Rngs(123)

    # 1. 自定义支持超参数输入的模型
    class HyperLinear(nnx.Module):
        def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
            self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
        
        def __call__(self, hyperparam: jnp.ndarray, inputs: jnp.ndarray) -> jnp.ndarray:
            # 可自定义超参数的作用逻辑,这里示例为缩放线性层输出
            base_pred = self.linear(inputs)
            return hyperparam * base_pred

    # 初始化模型
    model = HyperLinear(1, 1, rngs=rngs)

    skey = rngs.params()
    xx = random.uniform(skey, shape, minval=-10, maxval=10)

    # 生成观测数据
    def f(x, m=2.234, b=-1.123):
        return m * x + b
    obs1, obs2 = f(xx), f(xx)
    x1, x2 = xx, xx

    c = 0.9

    # 学习率调度器
    learning_rate_schedule = optax.schedules.cosine_decay_schedule(
        init_value=2e-1,
        decay_steps=int(c * epochs),
        alpha=0.01,
    )

    # 超参数调度器:从12线性增长到234
    hyperparam_schedule = optax.schedules.linear_schedule(
        init_value=12,
        end_value=234,
        transition_steps=int(c * epochs),
    )

    # 获取模型可训练参数
    params = nnx.filter(model, nnx.Param)
    # 初始化优化器
    optimizer = nnx.Optimizer(
        model,
        tx=optax.adam(learning_rate_schedule),
        wrt=params
    )

    # 2. 定义带超参数的扫描训练循环
    @nnx.scan(
        in_axes=(nnx.Carry, None, None, 0),  # step按序列传入每一步
        out_axes=(nnx.Carry, 0, 0, 0),  # 输出loss、mae、超参数值
        length=epochs
    )
    def optimizer_scan(carry, x, obs, step):
        model, optimizer = carry
        # 计算当前step对应的超参数值
        current_hyperparam = hyperparam_schedule(step)
        
        # 定义包含超参数的损失函数
        def loss_function(model, inputs, obs, hyperparam):
            prediction = model(hyperparam, inputs)  # 超参数传入模型
            error = obs - prediction
            loss = jnp.mean(error ** 2)
            mae = jnp.mean(jnp.abs(error))
            # 返回损失+辅助信息(mae+当前超参数)
            return loss, (mae, current_hyperparam)
        
        # 计算损失和梯度
        (loss, (mae, current_hyperparam)), grads = nnx.value_and_grad(
            loss_function, has_aux=True
        )(model, x, obs, current_hyperparam)
        
        # 更新模型参数
        optimizer.update(model, grads)
        
        return (model, optimizer), (loss, mae, current_hyperparam)

    # 生成训练步数序列,用于超参数计算
    steps = jnp.arange(epochs)

    # 3. 运行训练循环,收集所有结果
    (model, optimizer), (losses, maes, hyperparams) = optimizer_scan(
        (model, optimizer), x1, obs1, steps
    )

    # 打印训练结果
    print('AFTER TRAINING')
    print(f'最终训练MSE损失: {losses[-1]:.4f}')
    print(f'初始超参数值: {hyperparams[0]:.2f}')
    print(f'最终超参数值: {hyperparams[-1]:.2f}')

    # 测试集验证
    test_pred = model(hyperparam_schedule(epochs-1), x2)
    test_error = obs2 - test_pred
    test_loss = jnp.mean(test_error ** 2)
    print(f'测试集MSE损失: {test_loss:.4f}')
    print(f'模型权重近似值(m): {model.linear.kernel.value.item():.4f}')
    print(f'模型偏置近似值(b): {model.linear.bias.value.item():.4f}')

    # 可选:绘制超参数与损失变化曲线
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(hyperparams)
    plt.title('超参数随训练步数变化')
    plt.xlabel('训练轮次')
    plt.ylabel('超参数值')

    plt.subplot(1,2,2)
    plt.plot(losses, label='MSE损失')
    plt.plot(maes, label='MAE')
    plt.title('损失随训练步数变化')
    plt.xlabel('训练轮次')
    plt.ylabel('损失值')
    plt.legend()
    plt.tight_layout()
    plt.show()

关键修改点详解

1. 自定义支持超参数的模型

原来的nnx.Linear不支持额外输入超参数,我们自定义HyperLinear模块,让它的__call__方法接收超参数并参与前向计算。你可以根据自己的需求修改超参数的作用逻辑(比如缩放权重、调整偏置等),示例中是用超参数缩放线性层的输出。

2. 在扫描循环中动态获取超参数

Optax调度器本质是一个输入step返回对应值的函数,我们通过jnp.arange(epochs)生成训练步数序列,把它作为输入传入nnx.scan循环,每一步都能计算出当前的超参数值。

3. 调整扫描循环的输入输出,收集超参数

  • 修改nnx.scanin_axes,让步数序列按每一步传入循环;
  • 修改out_axes,让超参数值和loss、mae一起被收集为数组;
  • 在损失函数中把超参数作为辅助信息返回,确保梯度计算时不会影响参数更新(超参数是调度生成的,不需要求导)。

4. 修正优化器参数追踪

nnx.filter(model, nnx.Param)准确获取模型的可训练参数,避免原代码中params = nnx.Param的模糊定义,确保优化器只更新模型的可训练参数,不会影响超参数。

火山引擎 最新活动