JAX 의 ML framework 측면 — 이미 강력. 그런데 진짜 차별점 — ML 아닌 영역에서. JAX 가 차별화되는 이유: 수치 코드를 자유롭게 미분 + 가속.
전통적 과학 계산 stack:
- NumPy / SciPy — single-thread CPU. 미분 없음.
- MATLAB — toolbox 풍부, 그러나 closed-source, gradient 따로.
- Fortran / C++ — 빠르지만 hand-written gradient.
- Numba — JIT, 그러나 미분 없음.
JAX 가 합치는 두 가지:
- 임의 numerical 코드의 자동 미분: ODE 풀이, 시뮬레이션, sampling — 모든 코드 미분 가능.
- 가속기 활용: GPU/TPU 위에서 동일한 Python 코드.
이 둘이 합쳐지면 — 새 카테고리의 알고리즘이 가능해져.
전형적 use case
1. Differentiable physics — 시뮬레이터를 통해 gradient 흘려, 물리 시스템 학습
def simulate_pendulum(initial_angle, length, dt, n_steps):
'''단순한 진자 시뮬레이션'''
theta, omega = initial_angle, 0.0
for _ in range(n_steps):
omega += -9.81 / length * jnp.sin(theta) * dt
theta += omega * dt
return theta
# 진자가 특정 각도에 도달하도록 길이를 학습
target_angle = 0.5
def loss(length):
final = simulate_pendulum(0.1, length, 0.01, 100)
return (final - target_angle) ** 2
# 자동 미분 — 시뮬레이션 통해 gradient 흐름
optimal_length = optimize_with_grad(loss)
2. Differential equations — Diffrax 가 미분 가능한 ODE/SDE solver
from diffrax import diffeqsolve, Tsit5, ODETerm
def lorenz(t, y, args):
x, y_, z = y
return jnp.array([
10 * (y_ - x),
x * (28 - z) - y_,
x * y_ - 8/3 * z,
])
solution = diffeqsolve(
ODETerm(lorenz),
Tsit5(),
t0=0.0, t1=10.0, dt0=0.01,
y0=jnp.array([1., 1., 1.]),
)
3. Bayesian inference — NumPyro 가 JAX 위의 Pyro
import numpyro
import numpyro.distributions as dist
def model(data):
mu = numpyro.sample("mu", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
with numpyro.plate("data", len(data)):
numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)
# NUTS sampler — JAX-jit 으로 자동 가속
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), num_samples=1000)
mcmc.run(rng_key, data)
4. Optimization — JAXopt, Lineax 같은 라이브러리
import jaxopt
# 비선형 least squares — 자동 미분 + GPU
solver = jaxopt.LevenbergMarquardt(residual_fun=residuals)
result = solver.run(init_params, data=observations)
5. Probabilistic programming — TensorFlow Probability JAX backend
6. Reinforcement learning — Brax (differentiable physics), MJX (MuJoCo)
7. Quantum computing — qujax, PennyLane JAX backend
8. Computational chemistry — JAX-MD (molecular dynamics)
🔬 미분 가능한 시뮬레이션의 의의
전통적 과학 — observation 으로부터 model parameter 추정. JAX-시대 — 시뮬레이션 자체가 미분 가능. 그래서 — gradient descent 로 parameter 학습 + 실험 설계 + sensitivity analysis 가 한 framework. "differentiable everything" 의 패러다임 — 박사 논문 한 권의 가치.
이 quest 의 다음 lesson 들 — Diffrax (ODE), NumPyro (Bayesian), Brax/JAX-MD (시뮬레이션). 각각 — JAX 에코시스템의 한 영역.