Diffrax: a JAX-native ODE/SDE/CDE solver library by Patrick Kidger (author of Equinox). Every solver is compatible with jit, grad, and vmap.
```shell
pip install diffrax
```
The simplest ODE
```python
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt

def vector_field(t, y, args):
    '''dy/dt = f(t, y)'''
    return -y  # exponential decay

solution = diffeqsolve(
    ODETerm(vector_field),
    solver=Tsit5(),  # 5th-order Tsitouras
    t0=0.0, t1=5.0, dt0=0.1,
    y0=jnp.array(1.0),
    saveat=SaveAt(ts=jnp.linspace(0, 5, 100)),
)
print(solution.ts.shape, solution.ys.shape)  # (100,) (100,)
```
Other solvers use the same API. These include Heun, Dopri5, Dopri8 (high order), and implicit or IMEX solvers for stiff problems such as Kvaerno5 and KenCarp.
Parameterized vector field
```python
def lorenz(t, y, args):
    '''Lorenz attractor'''
    sigma, rho, beta = args
    x, y_, z = y
    return jnp.array([
        sigma * (y_ - x),
        x * (rho - z) - y_,
        x * y_ - beta * z,
    ])

sol = diffeqsolve(
    ODETerm(lorenz),
    Tsit5(),
    t0=0., t1=10., dt0=0.01,
    y0=jnp.array([1., 1., 1.]),
    args=(10.0, 28.0, 8/3),  # Lorenz parameters (sigma, rho, beta)
    saveat=SaveAt(ts=jnp.linspace(0, 10, 1000)),
)
```
The most powerful part: gradients flow through the solver
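```python
import jax
# Uses lorenz, jnp, and the diffrax imports from the blocks above

params = (10.0, 28.0, 8 / 3)  # Lorenz parameters (sigma, rho, beta)

def simulate(initial_state, params):
    sol = diffeqsolve(
        ODETerm(lorenz),
        Tsit5(),
        t0=0., t1=5., dt0=0.01,
        y0=initial_state,
        args=params,
    )
    return sol.ys[-1]  # final state (by default only t1 is saved)

# Gradient of the (summed) final state with respect to the initial state
grad_initial = jax.grad(lambda y0: jnp.sum(simulate(y0, params)))
print(grad_initial(jnp.array([1., 1., 1.])))
# A gradient obtained by automatic differentiation through the ODE solver
```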
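In many other frameworks this requires writing separate adjoint-sensitivity code. With JAX and Diffrax you simply call grad. By default Diffrax differentiates through the solver's internal operations using `RecursiveCheckpointAdjoint`. PyTorch's torchdiffeq offers something similar, but JAX's ergonomics are cleaner.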
Neural ODE: the model is an ODE
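```python
import equinox as eqx
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5

class NeuralODE(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, key):
        self.mlp = eqx.nn.MLP(
            in_size=3, out_size=3, width_size=64, depth=3, key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

model = NeuralODE(jax.random.PRNGKey(0))

def integrate(model, y0, t1):
    sol = diffeqsolve(
        ODETerm(model),
        Tsit5(),
        t0=0., t1=t1, dt0=0.01,
        y0=y0,
    )
    return sol.ys[-1]

# Training: the model's weights define the ODE vector field
def loss(model, y0, t1, target):
    pred = integrate(model, y0, t1)
    return jnp.mean((pred - target) ** 2)

y0 = jnp.array([1., 1., 1.])  # placeholder example data
target = jnp.zeros(3)
# filter_grad differentiates only the model's array leaves (weights, biases);
# plain jax.grad would fail on non-array fields such as the activation function
grads = eqx.filter_grad(loss)(model, y0, 5.0, target)
```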
SDEs (stochastic differential equations)
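```python
import jax
import jax.numpy as jnp
from diffrax import (
    ControlTerm, Euler, MultiTerm, ODETerm, VirtualBrownianTree, diffeqsolve,
)

def drift(t, y, args): return -y
def diffusion(t, y, args): return 0.1 * jnp.eye(2)  # (2, 2) matrix acting on 2-D Brownian motion

bm = VirtualBrownianTree(t0=0.0, t1=10.0, tol=1e-3, shape=(2,), key=jax.random.PRNGKey(0))
# Diffrax expresses an SDE as a drift term plus a control term driven by the Brownian motion
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm))
sol = diffeqsolve(
    terms,
    Euler(),  # Euler-Maruyama: converges to the Ito solution
    t0=0., t1=10., dt0=0.01,
    y0=jnp.zeros(2),
    args=None,
)
```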
🔬 Where Diffrax fits
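scipy.integrate provides excellent solvers, but it does not integrate with JAX's autodiff system. Diffrax is JAX-native: it works with jit, grad, and vmap, and it runs accelerated on GPU/TPU. It covers ODEs, SDEs, and CDEs alike. It is a foundation for Neural ODEs, differentiable physics, and sensitivity analysis, and one of the best examples of JAX shining in scientific computing.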
See also Optimistix (root finding, fixed points) and Lineax (linear solvers) by the same author. JAX's scientific-computing stack is growing quickly.