基于JAX和JAXOPT的优化代码运行缓慢的原因及性能优化建议
你好呀!看到你用JAX和JAXOPT做群体药代动力学模型优化时遇到了速度瓶颈——相同逻辑的代码在R中仅需3-4秒,但JAX版本跑了90秒,确实让人头疼。作为刚接触Python、JAX和JAXOPT的新手,咱们一步步来排查问题、优化代码~
你的问题背景
你正在实现一个基于JAX和JAXOPT的群体药代动力学模型优化,相同逻辑的代码在R中仅需3-4秒,但JAX版本耗时约90秒。你希望得到代码质量和运行速度的优化建议。
用到的输入数据集(THEOPH.csv)
ID,AMT,TIME,DV,WT,MDV,EVID 1,4.02,0,0.74,79.6,1,1 1,.,0.25,2.84,.,0,0 1,.,0.57,6.57,.,0,0 1,.,1.12,10.5,.,0,0 1,.,2.02,9.66,.,0,0 1,.,3.82,8.58,.,0,0 1,.,5.1,8.36,.,0,0 1,.,7.03,7.47,.,0,0 1,.,9.05,6.89,.,0,0 1,.,12.12,5.94,.,0,0 1,.,24.37,3.28,.,0,0 2,4.4,0,0,72.4,1,1 2,.,0.27,1.72,.,0,0 2,.,0.52,7.91,.,0,0 2,.,1,8.31,.,0,0 2,.,1.92,8.33,.,0,0 2,.,3.5,6.85,.,0,0 2,.,5.02,6.08,.,0,0 2,.,7.03,5.4,.,0,0 2,.,9,4.55,.,0,0 2,.,12,3.01,.,0,0 2,.,24.3,0.9,.,0,0 3,4.53,0,0,70.5,1,1 3,.,0.27,4.4,.,0,0 3,.,0.58,6.9,.,0,0 3,.,1.02,8.2,.,0,0 3,.,2.02,7.8,.,0,0 3,.,3.62,7.5,.,0,0 3,.,5.08,6.2,.,0,0 3,.,7.07,5.3,.,0,0 3,.,9,4.9,.,0,0 3,.,12.15,3.7,.,0,0 3,.,24.17,1.05,.,0,0 4,4.4,0,0,72.7,1,1 4,.,0.35,1.89,.,0,0 4,.,0.6,4.6,.,0,0 4,.,1.07,8.6,.,0,0 4,.,2.13,8.38,.,0,0 4,.,3.5,7.54,.,0,0 4,.,5.02,6.88,.,0,0 4,.,7.02,5.78,.,0,0 4,.,9.02,5.33,.,0,0 4,.,11.98,4.19,.,0,0 4,.,24.65,1.15,.,0,0 5,5.86,0,0,54.6,1,1 5,.,0.3,2.02,.,0,0 5,.,0.52,5.63,.,0,0 5,.,1,11.4,.,0,0 5,.,2.02,9.33,.,0,0 5,.,3.5,8.74,.,0,0 5,.,5.02,7.56,.,0,0 5,.,7.02,7.09,.,0,0 5,.,9.1,5.9,.,0,0 5,.,12,4.37,.,0,0 5,.,24.35,1.57,.,0,0 6,4.,0,0,80.,1,1 6,.,0.27,1.29,.,0,0 6,.,0.58,3.08,.,0,0 6,.,1.15,6.44,.,0,0 6,.,2.03,6.32,.,0,0 6,.,3.57,5.53,.,0,0 6,.,5,4.94,.,0,0 6,.,7,4.02,.,0,0 6,.,9.22,3.46,.,0,0 6,.,12.1,2.78,.,0,0 6,.,23.85,0.92,.,0,0 7,4.95,0,0.15,64.6,1,1 7,.,0.25,0.85,.,0,0 7,.,0.5,2.35,.,0,0 7,.,1.02,5.02,.,0,0 7,.,2.02,6.58,.,0,0 7,.,3.48,7.09,.,0,0 7,.,5,6.66,.,0,0 7,.,6.98,5.25,.,0,0 7,.,9,4.39,.,0,0 7,.,12.05,3.53,.,0,0 7,.,24.22,1.15,.,0,0 8,4.53,0,0,70.5,1,1 8,.,0.25,3.05,.,0,0 8,.,0.52,3.05,.,0,0 8,.,0.98,7.31,.,0,0 8,.,2.02,7.56,.,0,0 8,.,3.53,6.59,.,0,0 8,.,5.05,5.88,.,0,0 8,.,7.15,4.73,.,0,0 8,.,9.07,4.57,.,0,0 8,.,12.1,3,.,0,0 8,.,24.12,1.25,.,0,0 9,3.1,0,0,86.4,1,1 9,.,0.3,7.37,.,0,0 9,.,0.63,9.03,.,0,0 9,.,1.05,7.14,.,0,0 9,.,2.02,6.33,.,0,0 9,.,3.53,5.66,.,0,0 9,.,5.02,5.67,.,0,0 9,.,7.17,4.24,.,0,0 9,.,8.8,4.11,.,0,0 9,.,11.6,3.16,.,0,0 9,.,24.43,1.12,.,0,0 10,5.5,0,0.24,58.2,1,1 
10,.,0.37,2.89,.,0,0 10,.,0.77,5.22,.,0,0 10,.,1.02,6.41,.,0,0 10,.,2.05,7.83,.,0,0 10,.,3.55,10.21,.,0,0 10,.,5.05,9.18,.,0,0 10,.,7.08,8.02,.,0,0 10,.,9.38,7.14,.,0,0 10,.,12.1,5.68,.,0,0 10,.,23.7,2.42,.,0,0 11,4.92,0,0,65.,1,1 11,.,0.25,4.86,.,0,0 11,.,0.5,7.24,.,0,0 11,.,0.98,8,.,0,0 11,.,1.98,6.81,.,0,0 11,.,3.6,5.87,.,0,0 11,.,5.02,5.22,.,0,0 11,.,7.03,4.45,.,0,0 11,.,9.03,3.62,.,0,0 11,.,12.12,2.69,.,0,0 11,.,24.08,0.86,.,0,0 12,5.3,0,0,60.5,1,1 12,.,0.25,1.25,.,0,0 12,.,0.5,3.96,.,0,0 12,.,1,7.82,.,0,0 12,.,2,9.72,.,0,0 12,.,3.52,9.75,.,0,0 12,.,5.07,8.57,.,0,0 12,.,7.07,6.59,.,0,0 12,.,9.03,6.11,.,0,0 12,.,12.05,4.57,.,0,0 12,.,24.15,1.17,.,0,0
你当前的JAX/JAXOPT代码
import jax
import jax.numpy as jnp
import pandas as pd
import jaxopt as jaxopt
import time

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)


# Calculate the sqrt of the inverse of a matrix
# NOTE(review): this helper is not called anywhere in this snippet.
@jax.jit
def mat_sqrt_inv(mat):
    """Return V @ diag(1 / sqrt(|w|)) @ V.T from the eigh decomposition of mat.

    The absolute value of each eigenvalue is taken before the square root,
    so negative eigenvalues do not produce NaN. `jnp.linalg.eigh` is used,
    so `mat` is presumably symmetric -- confirm with the caller.
    """
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    d2 = 1.0 / jnp.sqrt(jnp.abs(eigvals))
    return eigvecs @ jnp.diag(d2) @ eigvecs.T


# Calculate the PRED value
@jax.jit
def PRED(THETA, ETA, DATA):
    """Compute the model prediction and related terms for one block of rows.

    Args:
        THETA: indexable with at least three entries; THETA[0], THETA[1],
            THETA[2] are combined with exp(ETA[i]) to form KA, V and K.
        ETA: indexable with at least three entries.
        DATA: 2-D array whose column 1 is read as TIME.

    Returns:
        Array stacking [F, G1, G2, G3, H1, H2] along axis 1, one row per
        entry of TIME.
    """
    # The dose is hard-coded rather than read from DATA.
    DOSE = 320.0
    TIME = DATA[:, 1]  # TIME values
    KA = THETA[0] * jnp.exp(ETA[0])
    V = THETA[1] * jnp.exp(ETA[1])
    K = THETA[2] * jnp.exp(ETA[2])
    # F is the predicted value; the expression divides by (KA - K).
    F = DOSE / V * KA / (KA - K) * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))
    # NOTE(review): G1..G3 look like the partial derivatives of F with
    # respect to ETA[0], ETA[1], ETA[2] (chain rule through KA, V, K) --
    # confirm against the model specification.
    G1 = (
        (DOSE / V / (KA - K) - DOSE / V * KA / (KA - K) ** 2)
        * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))
        + DOSE / V * KA / (KA - K) * (jnp.exp(-KA * TIME) * TIME)
    ) * KA
    G2 = -(
        DOSE / V ** 2 * KA / (KA - K) * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))
    ) * V
    G3 = (
        DOSE / V * KA / (KA - K) ** 2 * (jnp.exp(-K * TIME) - jnp.exp(-KA * TIME))
        - DOSE / V * KA / (KA - K) * (jnp.exp(-K * TIME) * TIME)
    ) * K
    # NOTE(review): H1/H2 are presumably the residual-error terms
    # (proportional and additive) -- confirm.
    H1 = F
    H2 = jnp.ones_like(F)
    return jnp.stack([F, G1, G2, G3, H1, H2], axis=1)


# Read data
df = pd.read_csv("THEOPH.csv")
# Keep only the rows with EVID == 0.
df = df[df['EVID'] == 0]
ID = df['ID'].unique()
NETA = 3
NEPS = 2  # not referenced again in this snippet
selected_columns = ['ID', 'TIME', 'DV']
DATASET = jnp.array(df[selected_columns].values)
# Split the dataset into one array per subject, selected by the ID column.
_subject_data_list = []
for subject_id_val in sorted(ID):  # Iterate through unique IDs in sorted order
    _subject_data_list.append(DATASET[DATASET[:, 0] == subject_id_val])

# Final NONMEM estimates
THETA = jnp.array([2.8864797758451806, 33.693922520723312, 8.7260954693777815E-002])
OMEGA = jnp.array([
    [0.88639937186299267, 0, 0],
    [0, 2.0903952408947064E-002, 0],
    [0, 0, 6.8919408612682198E-002]
])
SIGMA = jnp.array([
    [9.9259732296116312E-003, 0],
    [0, 3.2931599977662498E-002]
])


# Objective function (population-model log-likelihood logic filled in)
def obj_fn(params, subject_data_list):
    """Return the negative sum of per-subject log-likelihood terms.

    Args:
        params: tuple (THETA, OMEGA, SIGMA). Only THETA and SIGMA[0, 0] are
            used in the body; OMEGA is unpacked but never read.
        subject_data_list: iterable of 2-D arrays, one per subject, whose
            column 2 is read as DV and which are passed whole to PRED.

    Returns:
        Scalar: minus the accumulated log-likelihood.
    """
    THETA, OMEGA, SIGMA = params
    total_loglik = 0.0
    # Plain Python loop over subjects; each iteration calls PRED separately.
    for subject_data in subject_data_list:
        DV = subject_data[:, 2]
        # ETA is fixed at zeros here.
        pred_matrix = PRED(THETA, jnp.zeros(NETA), subject_data)
        pred = pred_matrix[:, 0]
        res = DV - pred
        # Simplified individual log-likelihood calculation
        loglik_ind = -0.5 * (jnp.dot(res, res) / SIGMA[0,0] + len(res)*jnp.log(2*jnp.pi*SIGMA[0,0]))
        total_loglik += loglik_ind
    return -total_loglik  # JAXOPT minimizes by default, so return the negative log-likelihood


# Initialize the optimizer
lbfgsb = jaxopt.LBFGSB(fun=obj_fn, maxiter=100)
params_init = (THETA, OMEGA, SIGMA)

# Run the optimization and time it
start_time = time.time()
# NOTE(review): no `bounds` argument is passed to LBFGSB.run here -- confirm
# that the jaxopt version in use accepts this call.
# NOTE(review): the timed interval presumably includes any JIT compilation
# triggered by the first call -- confirm when comparing against R.
result = lbfgsb.run(params_init, subject_data_list=_subject_data_list)
end_time = time.time()
print(f"Optimization took {end_time - start_time:.2f} seconds")
print(f"Optimized THETA: {result.params[0]}")
为什么你的JAX代码这么慢?核心原因
- JIT编译用得不全:你给`mat_sqrt_inv`和`PRED`加了`@jax.jit`,但最耗时的`obj_fn`没加!JAX的性能优势完全依赖XLA编译,核心函数不JIT等于白用JAX。
- Python循环拖垮性能:`obj_fn`里用Python循环遍历每个受试者。不加JIT时,每次迭代都要回到Python解释器逐个分发运算;即使加了JIT,Python循环也只会在追踪(trace)阶段被逐次展开,既拉长编译时间,也完全浪费了JAX的向量化能力。
- 数据结构不适合JAX:用Python列表存每个受试者的数据,JAX很难对其做向量化处理,更适合用统一形状的张量(例如形状为 受试者数 × 观测数 × 列数 的数组)。
- 重复计算冗余:`PRED`函数里多次重复计算`jnp.exp(-K * TIME)`、`jnp.exp(-KA * TIME)`、`KA - K`、`DOSE / V * KA / (KA - K)`等项,增加了不必要的运算量,建议先算一次并存入中间变量再复用。
具体优化建议
1. 用JIT+Vmap替代Python循环
这是最核心的优化!用jax.vmap把单受试者处理逻辑自动映射到所有受试者,再给obj_fn加@jax.jit:
# 先定义单受试者似然计算 @jax.jit def single_subject_loglik(THETA, OMEGA,




