NumPyro — Pyro 의 JAX backend. PyTorch 의 Pyro 가 dynamic, NumPyro 가 JAX 의 functional + jit 가속.
pip install numpyro
가장 단순 — 1D Gaussian mean estimation
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def model(data):
'''y ~ Normal(mu, 1), mu ~ Normal(0, 10)'''
mu = numpyro.sample("mu", dist.Normal(0., 10.))
with numpyro.plate("data", len(data)):
numpyro.sample("obs", dist.Normal(mu, 1.), obs=data)
# 합성 데이터
import numpy as np
true_mu = 3.7
data = jax.random.normal(jax.random.PRNGKey(0), (100,)) + true_mu
# MCMC
nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(1), data=data)
# 결과
samples = mcmc.get_samples()
print(f"posterior mu: mean={samples['mu'].mean():.3f}, std={samples['mu'].std():.3f}")
print(f"true mu: {true_mu}")
NUTS — No-U-Turn Sampler, HMC 의 자동 tuning 버전. 매 step 의 gradient 가 — JAX 의 autodiff 로 자동 jit. PyTorch Pyro 보다 — 큰 model 에서 5-10x 빠른 게 흔함.
linear regression
def linear_model(X, y=None):
n_features = X.shape[1]
w = numpyro.sample("w", dist.Normal(0., 1.).expand([n_features]))
b = numpyro.sample("b", dist.Normal(0., 1.))
sigma = numpyro.sample("sigma", dist.HalfNormal(1.))
mean = jnp.dot(X, w) + b
with numpyro.plate("data", len(X)):
numpyro.sample("y", dist.Normal(mean, sigma), obs=y)
# 학습
mcmc = MCMC(NUTS(linear_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X=X_train, y=y_train)
# prediction (posterior predictive)
samples = mcmc.get_samples()
y_pred = jnp.dot(X_test, samples["w"].T) + samples["b"] # (n_samples, n_test)
y_pred_mean = y_pred.mean(0)
y_pred_std = y_pred.std(0)
uncertainty 추정이 자연스럽게 — Bayesian 의 가장 큰 장점.
variational inference (VI)
MCMC 는 정확하지만 큰 데이터엔 느림. VI 는 빠른 근사.
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax
guide = AutoNormal(linear_model)
optimizer = numpyro.optim.optax_to_numpyro(optax.adam(0.01))
svi = SVI(linear_model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 5000, X=X_train, y=y_train)
# guide 로부터 posterior approximation 추출
params = svi_result.params
실전 — hierarchical model
def hierarchical(group_idx, x, y=None):
'''그룹별 random effect'''
n_groups = len(jnp.unique(group_idx))
# population-level
mu_w = numpyro.sample("mu_w", dist.Normal(0., 1.))
sigma_w = numpyro.sample("sigma_w", dist.HalfNormal(1.))
# group-level
with numpyro.plate("groups", n_groups):
w_g = numpyro.sample("w_g", dist.Normal(mu_w, sigma_w))
# data-level
with numpyro.plate("data", len(x)):
mean = w_g[group_idx] * x
numpyro.sample("y", dist.Normal(mean, 0.1), obs=y)
Bayesian hierarchy — JAX 위에서 깔끔히. 큰 데이터 + 많은 group 도 GPU 에서 빠르게.
🎲 NumPyro 의 가치
Bayesian inference 의 두 큰 비용 — gradient 평가 (모든 sample step 마다) + sample 의 sequential 의존성. JAX 의 jit + grad 가 첫 번째를 — 자동으로 — 가속. NUTS 의 sequential 부분만 남음. 결과: 100x model parameter, 100x data 의 inference 가 — 학생 노트북에서.
JAX-native Bayesian 도구 다른 옵션 — BlackJAX (모듈식 sampler), tfp.substrates.jax (TF Probability 의 JAX backend). 각자 강점 — NumPyro 가 가장 사용자 friendly.