Lesson 03 of 05 · published

Bayesian Inference with NumPyro

~8 min · scientific, jax, tutorial


NumPyro is Pyro on a JAX backend. Where Pyro on PyTorch is dynamic, NumPyro gets JAX's functional style plus jit acceleration.

pip install numpyro

The simplest case: 1D Gaussian mean estimation

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    '''y ~ Normal(mu, 1), mu ~ Normal(0, 10)'''
    mu = numpyro.sample("mu", dist.Normal(0., 10.))
    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.Normal(mu, 1.), obs=data)

# Synthetic data
true_mu = 3.7
data = jax.random.normal(jax.random.PRNGKey(0), (100,)) + true_mu

# MCMC
nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(1), data=data)

# Results
samples = mcmc.get_samples()
print(f"posterior mu: mean={samples['mu'].mean():.3f}, std={samples['mu'].std():.3f}")
print(f"true mu: {true_mu}")
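Because the prior and the likelihood are both Gaussian and the noise scale is known, this model has a closed-form posterior, which makes a handy sanity check for the MCMC output. The sketch below is not part of the lesson's code: the helper name `gaussian_mean_posterior` and the constant `toy_data` are hypothetical, and it uses only the standard library.

```python
import math

def gaussian_mean_posterior(data, prior_mean=0.0, prior_std=10.0, noise_std=1.0):
    """Closed-form posterior of mu for mu ~ Normal(prior_mean, prior_std),
    y_i ~ Normal(mu, noise_std)."""
    n = len(data)
    prior_precision = 1.0 / prior_std ** 2      # 1 / sigma_0^2
    data_precision = n / noise_std ** 2         # n / sigma^2
    post_precision = prior_precision + data_precision
    post_mean = (prior_mean * prior_precision
                 + sum(data) / noise_std ** 2) / post_precision
    post_std = math.sqrt(1.0 / post_precision)
    return post_mean, post_std

# Hypothetical stand-in data: 100 observations, all equal to 3.7
toy_data = [3.7] * 100
post_mean, post_std = gaussian_mean_posterior(toy_data)
print(f"analytic posterior: mean={post_mean:.3f}, std={post_std:.3f}")
```

With the lesson's array, pass `data.tolist()` instead of `toy_data`; the mean and std reported by MCMC should land close to the analytic values (a std near 0.1 for 100 observations).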

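NUTS (No-U-Turn Sampler) is a self-tuning variant of HMC. The gradient needed at every step comes from JAX autodiff and is JIT-compiled automatically. On large models it is common to see NumPyro run 5-10x faster than PyTorch-based Pyro, though the speedup depends on the model and hardware.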
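To make the role of autodiff concrete, here is a minimal sketch of a single leapfrog step, the gradient-based update that HMC and NUTS repeat many times. It is a simplified illustration, not NumPyro's actual implementation; the potential (a standard normal) and the names `potential`, `leapfrog`, and `hamiltonian` are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def potential(q):
    # Negative log density of a standard normal, up to a constant
    return 0.5 * jnp.sum(q ** 2)

# Autodiff supplies the gradient; no hand-written derivative is needed
grad_potential = jax.grad(potential)

@jax.jit
def leapfrog(q, p, step_size):
    # Half step for momentum, full step for position, half step for momentum
    p = p - 0.5 * step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p

def hamiltonian(q, p):
    # Total energy: potential plus kinetic (unit mass matrix assumed)
    return potential(q) + 0.5 * jnp.sum(p ** 2)

q0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.8])
q1, p1 = leapfrog(q0, p0, 0.01)

energy_drift = float(jnp.abs(hamiltonian(q1, p1) - hamiltonian(q0, p0)))
print(f"energy drift after one step: {energy_drift:.2e}")
```

The energy drift stays small for a small step size, which is what lets HMC-style samplers accept long trajectories. NUTS adds automatic choice of trajectory length and step-size adaptation on top of this update.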

linear regression

import jax.numpy as jnp

def linear_model(X, y=None):
    n_features = X.shape[1]
    w = numpyro.sample("w", dist.Normal(0., 1.).expand([n_features]))
    b = numpyro.sample("b", dist.Normal(0., 1.))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.))

    mean = jnp.dot(X, w) + b
    with numpyro.plate("data", len(X)):
        numpyro.sample("y", dist.Normal(mean, sigma), obs=y)

# Fit (X_train and y_train are assumed to be defined: shapes (n_train, n_features) and (n_train,))
mcmc = MCMC(NUTS(linear_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X=X_train, y=y_train)

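# Prediction: posterior over the regression mean at X_test
# (observation noise sigma is not added here)
samples = mcmc.get_samples()
y_pred = samples["w"] @ X_test.T + samples["b"][:, None]   # (n_samples, n_test)
y_pred_mean = y_pred.mean(0)   # average over posterior samples
y_pred_std = y_pred.std(0)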

Uncertainty estimates come out naturally, which is the biggest advantage of the Bayesian approach.
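One common way to report that uncertainty is a credible interval computed from the posterior samples. The sketch below uses made-up stand-in samples (`w_samples`, `b_samples`) in place of real MCMC output, so the numbers are purely illustrative; only the array shapes mirror what `mcmc.get_samples()` would return for a one-feature model.

```python
import jax
import jax.numpy as jnp

key_w, key_b = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical stand-ins for posterior samples of w (n_samples, n_features)
# and b (n_samples,); real values would come from mcmc.get_samples()
w_samples = 2.0 + 0.1 * jax.random.normal(key_w, (1000, 1))
b_samples = 0.5 + 0.05 * jax.random.normal(key_b, (1000,))

X_test = jnp.array([[0.0], [1.0], [2.0]])

# One predicted mean per posterior sample and test point: (n_samples, n_test)
mu = w_samples @ X_test.T + b_samples[:, None]

# 90% credible interval per test point, taken across the sample axis
lower, upper = jnp.percentile(mu, jnp.array([5.0, 95.0]), axis=0)
for x_row, lo, hi in zip(X_test, lower, upper):
    print(f"x={float(x_row[0]):.1f}: 90% interval [{float(lo):.2f}, {float(hi):.2f}]")
```

This interval covers the regression mean only. To include observation noise, draw from `Normal(mu, sigma)` per sample or use `numpyro.infer.Predictive`.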

variational inference (VI)

MCMC is accurate (asymptotically exact) but slow on large data. VI is a fast approximation.

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax

guide = AutoNormal(linear_model)
optimizer = numpyro.optim.optax_to_numpyro(optax.adam(0.01))
svi = SVI(linear_model, guide, optimizer, loss=Trace_ELBO())

svi_result = svi.run(jax.random.PRNGKey(0), 5000, X=X_train, y=y_train)

# Variational parameters of the fitted guide (the posterior approximation)
params = svi_result.params

In practice: a hierarchical model

def hierarchical(group_idx, x, y=None):
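    '''Per-group random effect (one slope per group)'''
    # jnp.unique needs a concrete (non-traced) group_idx; if the model
    # arguments are traced, pass the number of groups in explicitly instead
    n_groups = len(jnp.unique(group_idx))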

    # population-level
    mu_w = numpyro.sample("mu_w", dist.Normal(0., 1.))
    sigma_w = numpyro.sample("sigma_w", dist.HalfNormal(1.))

    # group-level
    with numpyro.plate("groups", n_groups):
        w_g = numpyro.sample("w_g", dist.Normal(mu_w, sigma_w))

    # data-level
    with numpyro.plate("data", len(x)):
        mean = w_g[group_idx] * x
        numpyro.sample("y", dist.Normal(mean, 0.1), obs=y)

Bayesian hierarchies are expressed cleanly on top of JAX, and even large datasets with many groups run fast on a GPU.
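The step that ties the group level to the data level is the integer indexing `w_g[group_idx]`, which looks up each observation's group slope. A small sketch with hypothetical values (the arrays below are invented for illustration) shows what that gather produces:

```python
import jax.numpy as jnp

# Hypothetical group-level slopes for 3 groups
w_g = jnp.array([1.0, -2.0, 0.5])

# Group membership of each of the 6 observations (integers in 0..n_groups-1)
group_idx = jnp.array([0, 0, 1, 2, 2, 1])
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Gather: one slope per observation, shape (6,)
w_per_obs = w_g[group_idx]

# Data-level mean, as in the model above
mean = w_per_obs * x

# Number of groups, assuming indices are contiguous and start at 0
n_groups = int(group_idx.max()) + 1

print(w_per_obs)
print(mean)
print(n_groups)
```

Because the lookup is an ordinary array operation, it vectorizes over all observations at once, with no Python loop over groups.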

🎲 Why NumPyro is valuable

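Bayesian inference has two major costs: evaluating gradients (at every sampling step) and the sequential dependence between samples. JAX's jit + grad speeds up the first automatically, leaving only the sequential part of NUTS. The result: inference with on the order of 100x the model parameters and 100x the data can become feasible on a student laptop.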

Other JAX-native Bayesian tools include BlackJAX (modular samplers) and tfp.substrates.jax (the JAX backend of TensorFlow Probability). Each has its strengths; NumPyro is the most user-friendly.

Code

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import jax
import jax.numpy as jnp

# Define a Bayesian linear regression model
def linear_regression(x, y=None):
    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.HalfNormal(5))

    # Likelihood
    mu = alpha + beta * x
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)

# Generate synthetic data
key = jax.random.key(0)
true_alpha, true_beta = 2.0, 3.5
x = jnp.linspace(-5, 5, 100)
y = true_alpha + true_beta * x + 0.5 * jax.random.normal(key, (100,))

# Run MCMC with NUTS (No U-Turn Sampler)
kernel = NUTS(linear_regression)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.key(1), x=x, y=y)

# Get posterior samples
samples = mcmc.get_samples()
print(f"alpha: {samples['alpha'].mean():.2f} ± {samples['alpha'].std():.2f}")
print(f"beta:  {samples['beta'].mean():.2f} ± {samples['beta'].std():.2f}")
# Expect alpha ≈ 2.0 and beta ≈ 3.5 (the true values), with posterior
# std of roughly 0.05 and 0.02; exact digits vary by run and version

# Make predictions
predictive = Predictive(linear_regression, samples)
predictions = predictive(jax.random.key(2), x=jnp.array([0.0, 1.0, 2.0]))
print(f"Predictions: {predictions['obs'].mean(axis=0)}")
# ≈ [2.0, 5.5, 9.0]

import jax
import jax.numpy as jnp

# Monte Carlo estimation of pi using vmap
def estimate_pi(key, num_samples):
    keys = jax.random.split(key, 2)
    x = jax.random.uniform(keys[0], (num_samples,))
    y = jax.random.uniform(keys[1], (num_samples,))
    inside_circle = (x**2 + y**2) <= 1.0
    return 4.0 * jnp.mean(inside_circle)

# Run 100 independent estimates in parallel
keys = jax.random.split(jax.random.key(0), 100)
estimates = jax.vmap(estimate_pi, in_axes=(0, None))(keys, 10000)
print(f"π ≈ {estimates.mean():.4f} ± {estimates.std():.4f}")
# roughly π ≈ 3.14 ± 0.016 (exact digits depend on the JAX version and backend)

Exercise

Fit the simple Gaussian-mean model with the NumPyro NUTS sampler and inspect the posterior trace. The value of JAX here: the gradient evaluations inside NUTS are JIT-compiled automatically.
