You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于JAX和JAXOPT的优化代码运行缓慢的原因及性能优化建议

基于JAX和JAXOPT的优化代码运行缓慢的原因及性能优化建议

你好呀!看到你用JAX和JAXOPT做群体药代动力学模型优化时遇到了速度瓶颈——相同逻辑的代码在R中仅需3-4秒,但JAX版本跑了90秒,确实让人头疼。作为刚接触Python、JAX和JAXOPT的新手,咱们一步步来排查问题、优化代码~

你的问题背景

你正在实现一个基于JAX和JAXOPT的群体药代动力学模型优化,相同逻辑的代码在R中仅需3-4秒,但JAX版本耗时约90秒。你希望得到代码质量和运行速度的优化建议。

用到的输入数据集(THEOPH.csv)

ID,AMT,TIME,DV,WT,MDV,EVID
1,4.02,0,0.74,79.6,1,1
1,.,0.25,2.84,.,0,0
1,.,0.57,6.57,.,0,0
1,.,1.12,10.5,.,0,0
1,.,2.02,9.66,.,0,0
1,.,3.82,8.58,.,0,0
1,.,5.1,8.36,.,0,0
1,.,7.03,7.47,.,0,0
1,.,9.05,6.89,.,0,0
1,.,12.12,5.94,.,0,0
1,.,24.37,3.28,.,0,0
2,4.4,0,0,72.4,1,1
2,.,0.27,1.72,.,0,0
2,.,0.52,7.91,.,0,0
2,.,1,8.31,.,0,0
2,.,1.92,8.33,.,0,0
2,.,3.5,6.85,.,0,0
2,.,5.02,6.08,.,0,0
2,.,7.03,5.4,.,0,0
2,.,9,4.55,.,0,0
2,.,12,3.01,.,0,0
2,.,24.3,0.9,.,0,0
3,4.53,0,0,70.5,1,1
3,.,0.27,4.4,.,0,0
3,.,0.58,6.9,.,0,0
3,.,1.02,8.2,.,0,0
3,.,2.02,7.8,.,0,0
3,.,3.62,7.5,.,0,0
3,.,5.08,6.2,.,0,0
3,.,7.07,5.3,.,0,0
3,.,9,4.9,.,0,0
3,.,12.15,3.7,.,0,0
3,.,24.17,1.05,.,0,0
4,4.4,0,0,72.7,1,1
4,.,0.35,1.89,.,0,0
4,.,0.6,4.6,.,0,0
4,.,1.07,8.6,.,0,0
4,.,2.13,8.38,.,0,0
4,.,3.5,7.54,.,0,0
4,.,5.02,6.88,.,0,0
4,.,7.02,5.78,.,0,0
4,.,9.02,5.33,.,0,0
4,.,11.98,4.19,.,0,0
4,.,24.65,1.15,.,0,0
5,5.86,0,0,54.6,1,1
5,.,0.3,2.02,.,0,0
5,.,0.52,5.63,.,0,0
5,.,1,11.4,.,0,0
5,.,2.02,9.33,.,0,0
5,.,3.5,8.74,.,0,0
5,.,5.02,7.56,.,0,0
5,.,7.02,7.09,.,0,0
5,.,9.1,5.9,.,0,0
5,.,12,4.37,.,0,0
5,.,24.35,1.57,.,0,0
6,4.,0,0,80.,1,1
6,.,0.27,1.29,.,0,0
6,.,0.58,3.08,.,0,0
6,.,1.15,6.44,.,0,0
6,.,2.03,6.32,.,0,0
6,.,3.57,5.53,.,0,0
6,.,5,4.94,.,0,0
6,.,7,4.02,.,0,0
6,.,9.22,3.46,.,0,0
6,.,12.1,2.78,.,0,0
6,.,23.85,0.92,.,0,0
7,4.95,0,0.15,64.6,1,1
7,.,0.25,0.85,.,0,0
7,.,0.5,2.35,.,0,0
7,.,1.02,5.02,.,0,0
7,.,2.02,6.58,.,0,0
7,.,3.48,7.09,.,0,0
7,.,5,6.66,.,0,0
7,.,6.98,5.25,.,0,0
7,.,9,4.39,.,0,0
7,.,12.05,3.53,.,0,0
7,.,24.22,1.15,.,0,0
8,4.53,0,0,70.5,1,1
8,.,0.25,3.05,.,0,0
8,.,0.52,3.05,.,0,0
8,.,0.98,7.31,.,0,0
8,.,2.02,7.56,.,0,0
8,.,3.53,6.59,.,0,0
8,.,5.05,5.88,.,0,0
8,.,7.15,4.73,.,0,0
8,.,9.07,4.57,.,0,0
8,.,12.1,3,.,0,0
8,.,24.12,1.25,.,0,0
9,3.1,0,0,86.4,1,1
9,.,0.3,7.37,.,0,0
9,.,0.63,9.03,.,0,0
9,.,1.05,7.14,.,0,0
9,.,2.02,6.33,.,0,0
9,.,3.53,5.66,.,0,0
9,.,5.02,5.67,.,0,0
9,.,7.17,4.24,.,0,0
9,.,8.8,4.11,.,0,0
9,.,11.6,3.16,.,0,0
9,.,24.43,1.12,.,0,0
10,5.5,0,0.24,58.2,1,1
10,.,0.37,2.89,.,0,0
10,.,0.77,5.22,.,0,0
10,.,1.02,6.41,.,0,0
10,.,2.05,7.83,.,0,0
10,.,3.55,10.21,.,0,0
10,.,5.05,9.18,.,0,0
10,.,7.08,8.02,.,0,0
10,.,9.38,7.14,.,0,0
10,.,12.1,5.68,.,0,0
10,.,23.7,2.42,.,0,0
11,4.92,0,0,65.,1,1
11,.,0.25,4.86,.,0,0
11,.,0.5,7.24,.,0,0
11,.,0.98,8,.,0,0
11,.,1.98,6.81,.,0,0
11,.,3.6,5.87,.,0,0
11,.,5.02,5.22,.,0,0
11,.,7.03,4.45,.,0,0
11,.,9.03,3.62,.,0,0
11,.,12.12,2.69,.,0,0
11,.,24.08,0.86,.,0,0
12,5.3,0,0,60.5,1,1
12,.,0.25,1.25,.,0,0
12,.,0.5,3.96,.,0,0
12,.,1,7.82,.,0,0
12,.,2,9.72,.,0,0
12,.,3.52,9.75,.,0,0
12,.,5.07,8.57,.,0,0
12,.,7.07,6.59,.,0,0
12,.,9.03,6.11,.,0,0
12,.,12.05,4.57,.,0,0
12,.,24.15,1.17,.,0,0

你当前的JAX/JAXOPT代码

import jax
import jax.numpy as jnp
import pandas as pd
import jaxopt as jaxopt
import time

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

# Calculate the sqrt of the inverse of a matrix
@jax.jit
def mat_sqrt_inv(mat):
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    d2 = 1.0 / jnp.sqrt(jnp.abs(eigvals))
    return eigvecs @ jnp.diag(d2) @ eigvecs.T

# Calculate the PRED value
@jax.jit
def PRED(THETA, ETA, DATA):
    DOSE = 320.0
    TIME = DATA[:, 1] # TIME values
    KA = THETA[0] * jnp.exp(ETA[0])
    V = THETA[1] * jnp.exp(ETA[1])
    K = THETA[2] * jnp.exp(ETA[2])
    F = DOSE / V * KA / (KA - K) * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))
    G1 = ((DOSE / V / (KA - K) - DOSE / V * KA / (KA - K) ** 2) * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME)) + DOSE / V * KA / (KA - K) * (jnp.exp(-KA * TIME) * TIME)) * KA
    G2 = -(DOSE / V ** 2 * KA / (KA - K) * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))) * V
    G3 = (DOSE / V * KA / (KA - K) ** 2 * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME)) - DOSE / V * KA / (KA - K) * (jnp.exp(-K * TIME) * TIME)) * K
    H1 = F
    H2 = jnp.ones_like(F)
    return jnp.stack([F, G1, G2, G3, H1, H2], axis=1)

# Read data
df = pd.read_csv("THEOPH.csv")
df = df[df['EVID'] == 0]
ID = df['ID'].unique()
NETA = 3
NEPS = 2
selected_columns = ['ID', 'TIME', 'DV']
DATASET = jnp.array(df[selected_columns].values)

_subject_data_list = []
for subject_id_val in sorted(ID):
    # Iterate through unique IDs in sorted order
    _subject_data_list.append(DATASET[DATASET[:, 0] == subject_id_val])

# Final NONMEM estimates
THETA = jnp.array([2.8864797758451806, 33.693922520723312, 8.7260954693777815E-002])
OMEGA = jnp.array([
 [0.88639937186299267, 0, 0],
 [0, 2.0903952408947064E-002, 0],
 [0, 0, 6.8919408612682198E-002]
 ])
SIGMA = jnp.array([
 [9.9259732296116312E-003, 0],
 [0, 3.2931599977662498E-002]
 ])

# Objective function (补充完整群体模型对数似然逻辑)
def obj_fn(params, subject_data_list):
    THETA, OMEGA, SIGMA = params
    total_loglik = 0.0

    for subject_data in subject_data_list:
        DV = subject_data[:, 2]
        pred_matrix = PRED(THETA, jnp.zeros(NETA), subject_data)
        pred = pred_matrix[:, 0]
        res = DV - pred
        # 简化的个体对数似然计算
        loglik_ind = -0.5 * (jnp.dot(res, res) / SIGMA[0,0] + len(res)*jnp.log(2*jnp.pi*SIGMA[0,0]))
        total_loglik += loglik_ind

    return -total_loglik  # JAXOPT默认最小化,返回负对数似然

# 优化器初始化
lbfgsb = jaxopt.LBFGSB(fun=obj_fn, maxiter=100)
params_init = (THETA, OMEGA, SIGMA)

# 运行优化并计时
start_time = time.time()
result = lbfgsb.run(params_init, subject_data_list=_subject_data_list)
end_time = time.time()

print(f"Optimization took {end_time - start_time:.2f} seconds")
print(f"Optimized THETA: {result.params[0]}")

为什么你的JAX代码这么慢?核心原因

  1. JIT编译用得不全:你给mat_sqrt_invPRED加了@jax.jit,但最耗时的obj_fn没加!JAX的性能优势完全依赖XLA编译,核心函数不JIT等于白用JAX。
  2. Python循环拖垮性能obj_fn里用Python循环遍历每个受试者,JIT无法编译Python循环,每次迭代都要回到Python解释器,完全浪费了JAX的向量化能力。
  3. 数据结构不适合JAX:用Python列表存每个受试者的数据,JAX很难对其做向量化处理,更适合用统一形状的张量。
  4. 重复计算冗余PRED函数里多次重复计算jnp.exp(-K * TIME)ka_minus_k等项,增加了不必要的运算量。

具体优化建议

1. 用JIT+Vmap替代Python循环

这是最核心的优化!用jax.vmap把单受试者处理逻辑自动映射到所有受试者,再给obj_fn@jax.jit

# 先定义单受试者似然计算
@jax.jit
def single_subject_loglik(THETA, OMEGA,

火山引擎 最新活动