Speeding Up the Metropolis-Hastings Algorithm in Python: A Question on MCMC Sampling Performance
Problem Statement
I have a Metropolis-Hastings MCMC implementation to sample a posterior distribution, using SciPy for random sampling and PDF calculations:
import numpy as np
from scipy import stats

def get_samples(n):
    """
    Generate and return a randomly sampled posterior.

    For simplicity, Prior is fixed as Beta(a=2,b=5), Likelihood is fixed as Normal(0,2)

    :type n: int
    :param n: number of iterations

    :rtype: numpy.ndarray
    """
    x_t = stats.uniform(0,1).rvs() # initial value
    posterior = np.zeros((n,))
    for t in range(n):
        x_prime = stats.norm(loc=x_t).rvs() # candidate
        p1 = stats.beta(a=2,b=5).pdf(x_prime)*stats.norm(loc=0,scale=2).pdf(x_prime) # prior * likelihood
        p2 = stats.beta(a=2,b=5).pdf(x_t)*stats.norm(loc=0,scale=2).pdf(x_t) # prior * likelihood
        alpha = p1/p2 # ratio
        u = stats.uniform(0,1).rvs() # random uniform
        if u <= alpha:
            x_t = x_prime # accept
        posterior[t] = x_t
    posterior = posterior[np.where(posterior > 0)] # get rid of initial zeros that don't contribute to distribution
    return posterior
I prefer pure NumPy implementations to avoid explicit Python loops, but the Metropolis-Hastings algorithm's sequential conditional steps make this impossible, leading to slow runtime. Performance profiling shows most time is spent inside the loop: specifically random number generation, stats.beta().pdf(), and stats.norm().pdf() calls.
I tried optimizing with Numba, but Numba's support for random number generation is limited (only CUDA RNG supports normal and uniform distributions natively).
Questions:
- Are there Numba-compatible multi-distribution sampling schemes that can significantly speed up this code?
- More generally, what are effective methods to accelerate sampling from distributions like Beta, Gamma, Poisson in Python loops?
Great question—MCMC loops are inherently sequential, but there are several ways to cut down on the overhead from distribution calls and random number generation. Let's break down the solutions, starting with Numba-compatible approaches and moving to more general optimizations.
1. Optimizing with Numba (Using Numba's Native RNG and Precompiled PDFs)
Numba does support CPU-based random number generation for common distributions. There is no separate Numba RNG module to import on the CPU (the xoroshiro128p helpers live in numba.cuda.random and are CUDA-only). Instead, calls to NumPy's np.random functions such as np.random.uniform and np.random.normal are supported inside @njit functions and are compiled to machine code. Numba keeps its own generator state, so seed it with np.random.seed from inside jitted code. Writing the PDFs as small jitted functions also avoids the overhead of SciPy's class-based method calls.
Here's how to rewrite your function with Numba:
import numpy as np
from numba import njit

@njit
def beta_pdf(x, a, b):
    # Normalizing constant precomputed for the fixed a=2, b=5:
    # 1/B(2,5) = Gamma(7) / (Gamma(2) * Gamma(5)) = 720 / 24 = 30
    beta_const = 30.0
    if 0.0 < x < 1.0:
        return beta_const * (x ** (a - 1)) * ((1.0 - x) ** (b - 1))
    return 0.0

@njit
def norm_pdf(x, loc, scale):
    z = (x - loc) / scale
    return np.exp(-0.5 * z * z) / (scale * np.sqrt(2.0 * np.pi))

@njit
def get_samples_numba(n, seed=42):
    np.random.seed(seed)  # seeds Numba's internal generator (must be called from jitted code)
    x_t = np.random.uniform(0.0, 1.0)  # initial value
    posterior = np.zeros(n)
    for t in range(n):
        # Generate candidate from a normal proposal centred on the current state
        x_prime = np.random.normal(x_t, 1.0)  # loc=x_t, scale=1.0 (matches original code)
        # Calculate prior * likelihood for candidate and current state
        p1 = beta_pdf(x_prime, 2, 5) * norm_pdf(x_prime, 0.0, 2.0)
        p2 = beta_pdf(x_t, 2, 5) * norm_pdf(x_t, 0.0, 2.0)
        alpha = p1 / p2
        # Generate uniform for acceptance
        u = np.random.uniform(0.0, 1.0)
        if u <= alpha:
            x_t = x_prime
        posterior[t] = x_t
    # Filter out non-positive values (though with proper sampling this should be rare)
    posterior = posterior[posterior > 0]
    return posterior
Key Optimizations Here:
- Numba's Native RNG: The random draws inside the jitted function are compiled to machine code, avoiding Python-level overhead.
- Precompiled PDFs: Instead of calling SciPy's stats.beta(...).pdf() (which builds a frozen distribution object and validates its arguments on every call), we wrote simple JIT-compiled PDF functions. For fixed a=2 and b=5, we precomputed the beta function constant to avoid redundant calculations.
- @njit Decorator: Compiles the entire function to optimized machine code, eliminating Python loop overhead.
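This should be substantially faster than your original code, because each iteration no longer constructs SciPy distribution objects or passes through the interpreter. The exact factor depends on n and on your machine, and the first call includes JIT compilation time, so time a second call when benchmarking.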
2. General Accelerations for Multi-Distribution Sampling in Loops
If you're not tied to Numba, here are other strategies:
a. Vectorize Where Possible (Even in Sequential Algorithms)
While Metropolis-Hastings requires sequential acceptance/rejection, you can vectorize batches of candidate generation and PDF calculations if you use a "batch" MCMC approach (though this changes the algorithm slightly). For example, generate multiple candidates at once and accept/reject in batches, but this works best for certain proposal distributions.
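One part of the loop can be vectorized without changing the algorithm at all: with a random-walk proposal, neither the proposal noise nor the acceptance uniforms depend on the current state, so both can be drawn for all n iterations before the loop starts. The following is a minimal sketch under that assumption. The names log_target and get_samples_predrawn are illustrative and not from the original code, and the target is evaluated in log space with its constant factors dropped, since they cancel in the acceptance ratio.

```python
import math
import numpy as np

def log_target(x):
    """Unnormalized log of Beta(2, 5).pdf(x) * Normal(0, 2).pdf(x)."""
    if x <= 0.0 or x >= 1.0:
        return -math.inf  # zero density outside the Beta support
    # log(x^(2-1)) + log((1-x)^(5-1)) - 0.5 * (x/2)^2
    return math.log(x) + 4.0 * math.log1p(-x) - 0.5 * (x / 2.0) ** 2

def get_samples_predrawn(n, seed=0):
    rng = np.random.default_rng(seed)
    # All random numbers are drawn up front in vectorized calls
    steps = rng.normal(0.0, 1.0, size=n)   # proposal noise: x' = x_t + step
    log_u = np.log(1.0 - rng.random(n))    # log of uniforms on (0, 1]
    x_t = rng.random()                     # initial value in [0, 1)
    lp_t = log_target(x_t)
    chain = np.empty(n)
    for t in range(n):
        x_prime = x_t + steps[t]
        lp_prime = log_target(x_prime)
        # Accept when log(u) < log(p1) - log(p2); strict so -inf is never accepted
        if log_u[t] < lp_prime - lp_t:
            x_t, lp_t = x_prime, lp_prime
        chain[t] = x_t                     # record the state every iteration
    return chain
```

The loop itself is still sequential Python, so this mainly removes the per-iteration cost of the random number calls. Caching lp_t also means the target is evaluated once per iteration instead of twice.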
b. Use Specialized MCMC Libraries
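Libraries like PyMC3 or Stan via PyStan are built for MCMC and handle the low-level optimizations for you. They use compiled backends (C++ for Stan, Theano for PyMC3) rather than handwritten Python loops. For your example, PyMC3 code would look like this:
import pymc3 as pm

n = 10000  # number of posterior draws

with pm.Model() as model:
    prior = pm.Beta('prior', alpha=2, beta=5)
    # No observed data: add the Normal(0, 2) log-density at the sampled value as a
    # potential, so the target is prior * likelihood as in the original code
    pm.Potential('likelihood', pm.Normal.dist(mu=0, sigma=2).logp(prior))
    trace = pm.sample(n, tune=1000, chains=1, cores=1)

posterior = trace['prior']
This also gives you step-size tuning and sampler diagnostics out of the box. Note that pm.sample uses NUTS by default for continuous variables rather than Metropolis-Hastings; pass step=pm.Metropolis() if you want the same algorithm. For a one-dimensional toy target like this one, the framework overhead can outweigh the gain, so it pays off mainly for larger models.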
c. Use Cython for Low-Level Control
If you want even more control than Numba, Cython lets you write C-level code with Python syntax. You can use C libraries for random number generation (like GSL's random functions) and compile PDF calculations directly, which can match or exceed Numba's performance. However, this requires more boilerplate code than Numba.
d. Precompute Constants
For fixed distribution parameters (like your Beta(2,5) and Normal(0,2)), precompute all constants in the PDF formulas upfront. This avoids recalculating things like 1/(sigma*sqrt(2pi)) or the beta function in every loop iteration—something we did in the Numba example.
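For Metropolis-Hastings specifically, the constant factors do not need to be computed at all: they appear in both p1 and p2 and cancel in the ratio. The sketch below checks this numerically for the Beta(2,5) and Normal(0,2) example. The function names full_density and kernel are illustrative, and the two test points are arbitrary.

```python
import math

# Constants for the fixed Beta(2, 5) prior and Normal(0, 2) likelihood, computed once
INV_BETA = math.gamma(7.0) / (math.gamma(2.0) * math.gamma(5.0))  # 1/B(2, 5) = 30
NORM_CONST = 1.0 / (2.0 * math.sqrt(2.0 * math.pi))               # 1/(sigma*sqrt(2*pi))

def full_density(x):
    """Beta(2, 5).pdf(x) * Normal(0, 2).pdf(x) with all constants included."""
    if not 0.0 < x < 1.0:
        return 0.0
    return INV_BETA * x * (1.0 - x) ** 4 * NORM_CONST * math.exp(-x * x / 8.0)

def kernel(x):
    """The same density with the constant factors dropped."""
    if not 0.0 < x < 1.0:
        return 0.0
    return x * (1.0 - x) ** 4 * math.exp(-x * x / 8.0)

# Acceptance ratio for an arbitrary current state and candidate
x_t, x_prime = 0.30, 0.45
alpha_full = full_density(x_prime) / full_density(x_t)
alpha_kernel = kernel(x_prime) / kernel(x_t)
print(math.isclose(alpha_full, alpha_kernel))  # the two ratios agree
```

Precomputing constants still matters when the normalized density itself is needed, for example when reporting or plotting the PDF.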
3. Handling Less Common Distributions (Gamma, Poisson) with Numba
For distributions like Gamma or Poisson, the same approach works: Numba's nopython mode supports many of NumPy's np.random functions directly inside @njit code.
- Gamma: Use np.random.gamma(shape, scale)
- Poisson: Use np.random.poisson(lam)
- Beta: Use np.random.beta(a, b)
For their PDFs, you can write JIT-compiled functions similar to the Beta and Normal examples, using the standard library math functions that Numba supports (like math.gamma or math.lgamma for the Gamma function in Gamma PDF calculations).
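A minimal sketch of what this can look like, assuming the NumPy-style np.random.gamma, np.random.poisson and np.random.beta calls that Numba supports in nopython mode. The function names and parameter values are illustrative, and the try/except fallback is only there so the sketch still runs as plain Python when Numba is not installed.

```python
import math
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: run uncompiled if Numba is unavailable
    def njit(func):
        return func

@njit
def poisson_logpmf(k, lam):
    # log P(K = k) = k*log(lam) - lam - log(k!), with log(k!) = lgamma(k + 1)
    return k * math.log(lam) - lam - math.lgamma(k + 1.0)

@njit
def gamma_logpdf(x, shape, scale):
    # Log density of Gamma(shape, scale), valid for x > 0
    return ((shape - 1.0) * math.log(x) - x / scale
            - math.lgamma(shape) - shape * math.log(scale))

@njit
def draw_many(n, seed):
    np.random.seed(seed)  # seeds Numba's own generator when compiled
    g = np.empty(n)
    p = np.empty(n, dtype=np.int64)
    b = np.empty(n)
    for i in range(n):  # scalar draws inside a compiled loop
        g[i] = np.random.gamma(2.0, 3.0)  # shape=2, scale=3
        p[i] = np.random.poisson(4.0)     # lam=4
        b[i] = np.random.beta(2.0, 5.0)   # a=2, b=5
    return g, p, b
```

Working with log densities, as here, also avoids overflow in the Gamma function for large arguments, which is why math.lgamma is used rather than math.gamma.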
This question originally appeared on Stack Exchange; asked by PyRsquared.




