Accelerating Metropolis-Hastings with Multi-Distribution Sampling in Python (Avoiding Slow Loops)

Problem Statement

I have a Metropolis-Hastings MCMC implementation to sample a posterior distribution, using SciPy for random sampling and PDF calculations:

import numpy as np
from scipy import stats

def get_samples(n):
    """
    Generate and return a randomly sampled posterior.
    For simplicity, Prior is fixed as Beta(a=2,b=5), Likelihood is fixed as Normal(0,2)
    :type n: int
    :param n: number of iterations
    :rtype: numpy.ndarray
    """
    x_t = stats.uniform(0,1).rvs() # initial value
    posterior = np.zeros((n,))
    for t in range(n):
        x_prime = stats.norm(loc=x_t).rvs() # candidate
        p1 = stats.beta(a=2,b=5).pdf(x_prime)*stats.norm(loc=0,scale=2).pdf(x_prime) # prior * likelihood
        p2 = stats.beta(a=2,b=5).pdf(x_t)*stats.norm(loc=0,scale=2).pdf(x_t) # prior * likelihood
        alpha = p1/p2 # ratio
        u = stats.uniform(0,1).rvs() # random uniform
        if u <= alpha:
            x_t = x_prime # accept
        posterior[t] = x_t
    posterior = posterior[np.where(posterior > 0)] # get rid of initial zeros that don't contribute to distribution
    return posterior

I prefer pure NumPy implementations to avoid explicit Python loops, but the Metropolis-Hastings algorithm's sequential conditional steps make this impossible, leading to slow runtime. Performance profiling shows most time is spent inside the loop: specifically random number generation, stats.beta().pdf(), and stats.norm().pdf() calls.

I tried optimizing with Numba, but Numba's support for random number generation is limited (only CUDA RNG supports normal and uniform distributions natively).

Questions:

  • Are there Numba-compatible multi-distribution sampling schemes that can significantly speed up this code?
  • More generally, what are effective methods to accelerate sampling from distributions like Beta, Gamma, Poisson in Python loops?

Answer

Great question. MCMC loops are inherently sequential, but there are several ways to cut down the overhead from distribution calls and random number generation. Let's break down the solutions, starting with Numba-compatible approaches and moving to more general optimizations.

1. Optimizing with Numba (Using Numba's Native RNG and Precompiled PDFs)

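Numba does support CPU-based random number generation for common distributions. Inside @njit functions you can call NumPy's np.random functions directly (np.random.uniform, np.random.normal, np.random.beta, np.random.gamma, np.random.poisson, and others), and Numba compiles them to native code. The xoroshiro128p generators (create_xoroshiro128p_states and related functions) live in numba.cuda.random and are meant for CUDA kernels, which is likely the limitation you ran into. Numba also lets you JIT-compile the PDF calculations, avoiding the overhead of SciPy's class-based method calls.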

Here's how to rewrite your function with Numba:

import math
import numpy as np
from numba import njit

@njit
def beta_pdf(x, a, b, inv_beta_ab):
    # inv_beta_ab = 1/B(a, b), computed once by the caller (30 for a=2, b=5)
    if 0.0 < x < 1.0:
        return inv_beta_ab * x ** (a - 1) * (1.0 - x) ** (b - 1)
    return 0.0

@njit
def norm_pdf(x, loc, scale):
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))

@njit
def get_samples_numba(n, seed=42):
    # Seeds Numba's internal generator (separate from NumPy's global state)
    np.random.seed(seed)
    a, b = 2.0, 5.0
    # 1/B(a, b) = Gamma(a + b) / (Gamma(a) * Gamma(b)), computed once outside the loop
    inv_beta_ab = math.exp(math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b))
    x_t = np.random.uniform(0.0, 1.0)  # initial value
    posterior = np.zeros(n)

    for t in range(n):
        # Candidate from Normal(loc=x_t, scale=1), as in the original code
        x_prime = np.random.normal(x_t, 1.0)
        # Prior * likelihood for the candidate and the current state
        p1 = beta_pdf(x_prime, a, b, inv_beta_ab) * norm_pdf(x_prime, 0.0, 2.0)
        p2 = beta_pdf(x_t, a, b, inv_beta_ab) * norm_pdf(x_t, 0.0, 2.0)
        alpha = p1 / p2
        # Accept with probability min(1, alpha); u lies in [0, 1), so the strict <
        # never accepts a zero-density candidate
        u = np.random.uniform(0.0, 1.0)
        if u < alpha:
            x_t = x_prime
        posterior[t] = x_t

    # Zero-density candidates (x <= 0 or x >= 1) are always rejected, so every
    # state stays in (0, 1) and no filtering is needed
    return posterior

Key Optimizations Here:

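  • Numba's CPU RNG: Inside @njit functions, np.random.normal and np.random.uniform compile to native code with no Python-level overhead. Numba keeps its own generator state, so seed it by calling np.random.seed inside jitted code. Seeding from regular Python does not affect it.
  • Hand-written PDFs: Instead of constructing frozen SciPy distributions such as stats.beta(a=2, b=5) on every iteration, we use small JIT-compiled PDF functions. Each SciPy call goes through object creation and argument validation. For fixed parameters, the constant 1/B(a, b) only needs to be computed once rather than on every call. It equals 30 for a=2, b=5, since B(2,5) = Gamma(2)Gamma(5)/Gamma(7) = 24/720 = 1/30.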
  • @njit Decorator: Compiles the entire function to optimized machine code, eliminating Python loop overhead.

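Your original spends most of its time creating frozen SciPy distribution objects and calling their methods from Python on every iteration. Removing that overhead usually gives a large speedup. The exact factor depends on n and your hardware, so benchmark it yourself. Time a second call, because the first call to an @njit function includes compilation.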
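Because hand-written PDFs replace SciPy's, it is worth checking them once against SciPy before trusting the sampler. This check catches mistakes in hand-derived constants. The sketch below uses plain-Python versions of the PDF helpers, so Numba is not needed. It assumes SciPy is installed, as in the question. The 1/B(a, b) constant is derived from `math.lgamma` rather than typed in.

```python
import math

import numpy as np
from scipy import stats


def beta_pdf(x, a, b):
    # 1/B(a, b) via log-gamma; B(2, 5) = Gamma(2) * Gamma(5) / Gamma(7) = 1/30
    inv_beta_ab = math.exp(math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b))
    if 0.0 < x < 1.0:
        return inv_beta_ab * x ** (a - 1) * (1.0 - x) ** (b - 1)
    return 0.0


def norm_pdf(x, loc, scale):
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


# Compare prior * likelihood on a grid that also covers points outside (0, 1)
xs = np.linspace(-0.5, 1.5, 41)
ours = np.array([beta_pdf(x, 2, 5) * norm_pdf(x, 0.0, 2.0) for x in xs])
reference = stats.beta(a=2, b=5).pdf(xs) * stats.norm(loc=0, scale=2).pdf(xs)
max_abs_diff = float(np.max(np.abs(ours - reference)))
print(f"max |ours - scipy| = {max_abs_diff:.2e}")
```

The same comparison works on the @njit versions, since jitted functions can be called from regular Python.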

2. General Accelerations for Multi-Distribution Sampling in Loops

If you're not tied to Numba, here are other strategies:

a. Vectorize Where Possible (Even in Sequential Algorithms)

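While Metropolis-Hastings requires a sequential accept/reject step, the work around it can still be vectorized without changing the algorithm. With a random-walk proposal like yours, the proposal noise and the acceptance uniforms do not depend on the current state. You can therefore draw all of them up front in single NumPy calls. If you need many samples, you can also run independent chains side by side and vectorize each step across chains. Schemes that propose several candidates per step (such as multiple-try Metropolis) exist too, but they change the algorithm and only pay off for certain proposal distributions.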
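Here is a minimal pure-NumPy sketch of both ideas. It assumes a random-walk proposal with scale 1, as in the question. `log_target` and `run_chains` are illustrative names. The log density drops constant factors because they cancel in the acceptance ratio. The Python loop over steps remains, but each step updates all chains at once, so its cost is shared across chains.

```python
import numpy as np


def log_target(x):
    # Unnormalized log of Beta(2, 5) pdf * Normal(0, 2) pdf, elementwise.
    # Constants are dropped because they cancel in the acceptance ratio.
    out = np.full(x.shape, -np.inf)
    inside = (x > 0.0) & (x < 1.0)  # the Beta prior is zero outside (0, 1)
    xi = x[inside]
    out[inside] = np.log(xi) + 4.0 * np.log1p(-xi) - xi**2 / 8.0
    return out


def run_chains(n_steps, n_chains, seed=0):
    rng = np.random.default_rng(seed)
    # Proposal noise and acceptance uniforms do not depend on the chain state,
    # so all of them are drawn up front in two vectorized calls.
    noise = rng.standard_normal((n_steps, n_chains))
    u = rng.uniform(size=(n_steps, n_chains))
    x = rng.uniform(size=n_chains)  # one starting value per chain
    logp = log_target(x)
    samples = np.empty((n_steps, n_chains))
    for t in range(n_steps):  # sequential in time, vectorized across chains
        x_prime = x + noise[t]
        logp_prime = log_target(x_prime)
        # Accept where u < min(1, p(x') / p(x)); np.minimum avoids overflow in exp
        accept = u[t] < np.exp(np.minimum(0.0, logp_prime - logp))
        x = np.where(accept, x_prime, x)
        logp = np.where(accept, logp_prime, logp)
        samples[t] = x
    return samples


samples = run_chains(n_steps=1_000, n_chains=500, seed=1)
pooled = samples[200:].ravel()  # drop burn-in, then pool all chains
print(pooled.size, pooled.mean())
```

Drawing everything up front trades memory for speed, since the noise and uniform arrays have n_steps * n_chains entries. For very long runs, draw them in blocks instead.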

b. Use Specialized MCMC Libraries

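Libraries like PyMC (formerly PyMC3) or Stan (via PyStan or CmdStanPy) are built for MCMC and handle the low-level optimizations for you. Stan compiles the model and sampler to C++. PyMC compiles the model's log density and gradients through PyTensor, the successor to Theano/Aesara. Either way, there is no handwritten Python loop over density evaluations. For your example, a PyMC model might look like this (a sketch assuming PyMC 4 or later):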

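import pymc as pm  # PyMC 4 or later (the successor to PyMC3)

n = 10_000  # number of draws to keep

with pm.Model() as model:
    x = pm.Beta('x', alpha=2, beta=5)  # prior
    # Multiply the target by the Normal(0, 2) density at x, like
    # stats.norm(loc=0, scale=2).pdf(x) in the question
    pm.Potential('likelihood', pm.logp(pm.Normal.dist(mu=0, sigma=2), x))
    idata = pm.sample(draws=n, tune=1000, chains=1)
posterior = idata.posterior['x'].values.ravel()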

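Beyond moving the heavy numerical work into compiled code, these libraries use adaptive samplers (PyMC defaults to NUTS). They tune step sizes during warm-up and report convergence diagnostics. For a one-dimensional toy target like this one, model compilation and sampler overhead may outweigh the gains, so benchmark before switching.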

c. Use Cython for Low-Level Control

If you want even more control than Numba, Cython lets you write C-level code with Python syntax. You can call C libraries for random number generation (such as GSL's random number functions) and compile PDF calculations directly. Performance is generally comparable to Numba's. The cost is more boilerplate: type declarations and a separate build step.

d. Precompute Constants

For fixed distribution parameters (like your Beta(2,5) and Normal(0,2)), compute all constants in the PDF formulas once, up front. This avoids recalculating things like 1/(sigma*sqrt(2*pi)) or the Beta function on every loop iteration. The Numba example does this for the Beta function constant.
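You can go one step further. The normalizing constants appear in both p1 and p2, so they cancel in alpha = p1/p2. For this target you only need log p(x) = log x + 4 log(1 - x) - x^2/8, up to an additive constant.

The sketch below has two features:

- It works in log space, which also avoids underflow when densities get very small.
- It caches the current state's log density, so each iteration evaluates the target once instead of twice.

`log_target` and `sample_log_space` are illustrative names. The try/except fallback to plain Python is an assumption, included only so the sketch runs without Numba installed.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python when Numba is missing
    def njit(func):
        return func


@njit
def log_target(x):
    # log[Beta(2,5) pdf * Normal(0,2) pdf] up to an additive constant:
    # 1/B(2,5) and 1/(2*sqrt(2*pi)) are dropped since they cancel in the ratio
    if x <= 0.0 or x >= 1.0:
        return -np.inf
    return math.log(x) + 4.0 * math.log1p(-x) - x * x / 8.0


@njit
def sample_log_space(n, seed=42):
    np.random.seed(seed)
    x_t = np.random.uniform(0.0, 1.0)
    logp_t = log_target(x_t)  # cached log density of the current state
    out = np.empty(n)
    for t in range(n):
        x_prime = np.random.normal(x_t, 1.0)
        logp_prime = log_target(x_prime)  # one target evaluation per step
        # Accept with probability min(1, exp(logp_prime - logp_t));
        # clipping at 0 keeps math.exp from overflowing
        if np.random.uniform(0.0, 1.0) < math.exp(min(0.0, logp_prime - logp_t)):
            x_t = x_prime
            logp_t = logp_prime
        out[t] = x_t
    return out


samples = sample_log_space(20_000, seed=3)
```

The random streams differ between the Numba and plain-Python paths, so the same seed gives different (but equally valid) chains in each.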

3. Handling Less Common Distributions (Gamma, Poisson) with Numba

For distributions like Gamma or Poisson, Numba's CPU support for NumPy's random functions has you covered inside @njit code:

  • Gamma: Use np.random.gamma(shape, scale)
  • Poisson: Use np.random.poisson(lam)
  • Beta: Use np.random.beta(a, b)

For their PDFs, you can write JIT-compiled functions similar to the Beta and Normal examples. Use the standard-library math module, which Numba compiles. For example, math.lgamma gives the log of the Gamma function for the Gamma PDF, and log k! = math.lgamma(k + 1) for the Poisson PMF.
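As a sketch of that pattern, the functions below do two things:

- compute Gamma and Poisson log densities with `math.lgamma`
- draw from `np.random.gamma` and `np.random.poisson` inside a jitted loop

The function names and parameter values (shape 3, scale 2, lam 4) are illustrative. The try/except fallback to plain Python is an assumption so the sketch runs without Numba.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python when Numba is missing
    def njit(func):
        return func


@njit
def gamma_logpdf(x, shape, scale):
    # log density of Gamma(shape, scale) at x
    if x <= 0.0:
        return -np.inf
    return ((shape - 1.0) * math.log(x) - x / scale
            - math.lgamma(shape) - shape * math.log(scale))


@njit
def poisson_logpmf(k, lam):
    # log P(K = k) for K ~ Poisson(lam); log(k!) = lgamma(k + 1)
    if k < 0:
        return -np.inf
    return k * math.log(lam) - lam - math.lgamma(k + 1.0)


@njit
def draw_many(n, seed):
    np.random.seed(seed)
    g = np.empty(n)
    p = np.empty(n, dtype=np.int64)
    for i in range(n):
        g[i] = np.random.gamma(3.0, 2.0)  # shape=3, scale=2 -> mean 6
        p[i] = np.random.poisson(4.0)     # lam=4 -> mean 4
    return g, p


g, p = draw_many(100_000, 7)
print(g.mean(), p.mean(), gamma_logpdf(1.0, 1.0, 2.0), poisson_logpmf(0, 4.0))
```

Log densities are usually the better choice inside MCMC loops, because the acceptance test only needs differences of log densities.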


The question comes from Stack Exchange; asked by PyRsquared.
