C.W.K.
Stream
Lesson 02 of 05 · published

Differential Equations 와 Diffrax

~9 min · scientific, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

Diffrax — JAX-native ODE/SDE/CDE 솔버 라이브러리. Equinox 의 작성자인 Patrick Kidger 가 만들었다. 모든 solver 가 jit/grad/vmap 과 호환된다.

pip install diffrax

가장 단순한 ODE

import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt

def vector_field(t, y, args):
    """Return dy/dt for exponential decay, i.e. f(t, y) = -y.

    The (t, y, args) signature is the one diffrax's ODETerm expects;
    only ``y`` is actually used here.
    """
    dydt = -y  # exponential decay
    return dydt

save_points = jnp.linspace(0, 5, 100)
solution = diffeqsolve(
    ODETerm(vector_field),
    Tsit5(),  # Tsitouras' 5th-order explicit Runge-Kutta method
    t0=0.0,
    t1=5.0,
    dt0=0.1,
    y0=jnp.array(1.0),
    saveat=SaveAt(ts=save_points),
)

# Scalar y0 -> one saved value per requested time point.
print(solution.ts.shape, solution.ys.shape)   # (100,) (100,)

다양한 solver 를 같은 API 로 바꿔 쓸 수 있다 — Heun, Dopri5, Dopri8 (high-order), 그리고 stiff problem 용 implicit solver (KenCarp 등).

parameterized vector field

def lorenz(t, y, args):
    """Lorenz-system vector field.

    ``y`` is unpacked as the three state components (x, y, z) and ``args``
    as the parameter triple (sigma, rho, beta). ``t`` is unused, because
    the right-hand side does not depend on time.
    """
    sigma, rho, beta = args
    px, py, pz = y
    dx_dt = sigma * (py - px)
    dy_dt = px * (rho - pz) - py
    dz_dt = px * py - beta * pz
    return jnp.array([dx_dt, dy_dt, dz_dt])

lorenz_params = (10.0, 28.0, 8/3)   # (sigma, rho, beta), unpacked in that order by lorenz()
sol = diffeqsolve(
    ODETerm(lorenz),
    Tsit5(),
    t0=0.,
    t1=10.,
    dt0=0.01,
    y0=jnp.ones(3),
    args=lorenz_params,
    saveat=SaveAt(ts=jnp.linspace(0, 10, 1000)),
)

가장 강력한 부분 — gradient 흐름

import jax

def simulate(initial_state, params):
    """Integrate the Lorenz system from t=0 to t=5 and return the final state.

    ``params`` is forwarded to ``lorenz`` as its ``args`` (sigma, rho, beta).
    No ``saveat`` is passed, so the returned value is the last saved state
    (``ys[-1]``), i.e. the state at t1.
    """
    solution = diffeqsolve(
        ODETerm(lorenz), Tsit5(),
        t0=0., t1=5., dt0=0.01,
        y0=initial_state, args=params,
    )
    final_state = solution.ys[-1]
    return final_state

# Lorenz parameters (sigma, rho, beta). They must be defined before the grad
# call below, which closes over `params`.
params = (10.0, 28.0, 8/3)

# Gradient of sum(final state) with respect to the initial state.
grad_initial = jax.grad(lambda y0: jnp.sum(simulate(y0, params)))
print(grad_initial(jnp.array([1., 1., 1.])))
# The gradient is obtained by automatic differentiation through the ODE solver.
# ODE solver 를 통해 자동 미분된 gradient

다른 framework 에서는 보통 adjoint sensitivity 코드를 별도로 작성해야 한다. JAX/Diffrax 에서는 그냥 jax.grad 를 호출하면 된다. PyTorch 의 torchdiffeq 도 비슷한 기능을 제공하지만, JAX 쪽의 ergonomics 가 더 깔끔하다.

Neural ODE — 모델이 ODE

import equinox as eqx  # required for eqx.Module / eqx.nn.MLP below


class NeuralODE(eqx.Module):
    """ODE vector field dy/dt = MLP(y) whose weights are learned.

    ``t`` and ``args`` are accepted to match diffrax's (t, y, args)
    vector-field signature but are ignored: the MLP sees only ``y``.
    """

    mlp: eqx.nn.MLP

    def __init__(self, key):
        # 3 -> 3 mapping: the state and its time derivative share a shape.
        self.mlp = eqx.nn.MLP(
            in_size=3, out_size=3, width_size=64, depth=3, key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

model = NeuralODE(key=jax.random.PRNGKey(0))  # fixed seed -> reproducible MLP initialisation

def integrate(model, y0, t1):
    """Solve dy/dt = model(t, y, args) from t=0 to ``t1`` starting at ``y0``.

    Returns the last saved state (``ys[-1]``), i.e. the state at ``t1``.
    """
    term = ODETerm(model)
    solution = diffeqsolve(term, Tsit5(), t0=0., t1=t1, dt0=0.01, y0=y0)
    return solution.ys[-1]

# Training objective: the model's weights define the ODE vector field.
def loss(model, y0, t1, target):
    """Mean squared error between the integrated final state and ``target``."""
    residual = integrate(model, y0, t1) - target
    return jnp.mean(residual ** 2)

import equinox as eqx

# Example inputs (placeholders, replace with real data). They are
# 3-dimensional to match the MLP's in_size/out_size.
y0 = jnp.array([1.0, 1.0, 1.0])
target = jnp.zeros(3)

# jax.grad rejects the model pytree because it also contains non-array
# leaves (e.g. the MLP's activation functions). eqx.filter_grad
# differentiates only the floating-point array leaves and treats the rest
# as static.
grads = eqx.filter_grad(loss)(model, y0, 5.0, target)

SDE (확률 미분방정식)

import jax
import jax.numpy as jnp
from diffrax import (
    ControlTerm,
    Euler,
    MultiTerm,
    ODETerm,
    VirtualBrownianTree,
    diffeqsolve,
)

# diffrax has no dedicated SDE class. An SDE  dy = f(t, y) dt + g(t, y) dW
# is expressed by combining a drift ODETerm with a ControlTerm driven by a
# Brownian motion.


def drift(t, y, args):
    """Drift f(t, y) = -y."""
    return -y


def diffusion(t, y, args):
    """Diffusion g(t, y): a constant (2, 2) matrix applied to the 2-dim dW."""
    return 0.1 * jnp.eye(2)


bm = VirtualBrownianTree(t0=0, t1=10, tol=1e-3, shape=(2,), key=jax.random.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm))
sol = diffeqsolve(
    terms,
    Euler(),  # Euler-Maruyama; converges to the Itô solution
    t0=0., t1=10., dt0=0.01,
    y0=jnp.zeros(2),
    args=None,
)

🔬 Diffrax 의 위치

scipy.integrate 는 훌륭한 solver 지만, JAX autodiff 시스템과 통합되지 않는다. 반면 Diffrax 는 JAX-native 이며 jit/grad/vmap 과 호환되고 GPU/TPU 가속을 받는다. ODE/SDE/CDE 를 모두 지원한다. Neural ODE, differentiable physics, sensitivity analysis 의 기반이 되며, JAX 가 과학 계산에서 빛나는 대표적인 사례다.

참고 — 같은 작가의 Optimistix (root finding, fixed-point), Lineax (linear solvers). JAX 의 과학 계산 stack 이 빠르게 자라는 중.

Code

import diffrax
import jax.numpy as jnp

# Solve a simple ODE: dy/dt = -y (exponential decay)
def vector_field(t, y, args):
    """Exponential-decay right-hand side; ``t`` and ``args`` are unused."""
    rate_of_change = -y
    return rate_of_change

# Adaptive step sizes: the PID controller adjusts dt to meet rtol/atol.
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
saveat = diffrax.SaveAt(ts=jnp.linspace(0, 5, 100))
solver = diffrax.Tsit5()  # Tsitouras 5th-order Runge-Kutta; a good default
term = diffrax.ODETerm(vector_field)

sol = diffrax.diffeqsolve(
    term, solver,
    t0=0, t1=5, dt0=0.1,
    y0=1.0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

print(sol.ts.shape)  # (100,) — time points
print(sol.ys.shape)  # (100,) — solution values
# sol.ys ≈ exp(-ts) — the exact solution
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp

class NeuralODE(eqx.Module):
    """Learned vector field dy/dt = MLP([y, t]).

    The time value is appended to the state as one extra feature before the
    MLP is applied, so the learned dynamics may depend on time.
    """

    net: eqx.nn.MLP

    def __init__(self, data_dim, hidden_dim, *, key):
        input_width = data_dim + 1  # state features plus one time feature
        self.net = eqx.nn.MLP(
            in_size=input_width,
            out_size=data_dim,
            width_size=hidden_dim,
            depth=2,
            key=key,
        )

    def __call__(self, t, y, args=None):
        time_feature = jnp.broadcast_to(t, (1,))
        return self.net(jnp.concatenate([y, time_feature]))

# Build a Neural ODE on 2-D states with a 64-unit-wide MLP.
key = jax.random.key(42)
model = NeuralODE(data_dim=2, hidden_dim=64, key=key)

# Forward pass: integrate the learned dynamics from t=0 to t=1.
controller = diffrax.PIDController(rtol=1e-3, atol=1e-3)
solver = diffrax.Tsit5()
term = diffrax.ODETerm(model)

y0 = jnp.array([1.0, 0.0])
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=1,
    dt0=0.1,
    y0=y0,
    stepsize_controller=controller,
)

# The whole solve is differentiable: filter_value_and_grad makes loss_fn
# return (loss, grads), with grads taken w.r.t. the model's array leaves.
@eqx.filter_jit
@eqx.filter_value_and_grad
def loss_fn(model):
    """Mean of the squared final state; minimising it drives y(1) toward 0."""
    pid = diffrax.PIDController(rtol=1e-3, atol=1e-3)
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(model),
        diffrax.Tsit5(),
        t0=0,
        t1=1,
        dt0=0.1,
        y0=y0,
        stepsize_controller=pid,
    )
    final_state = solution.ys[-1]
    return jnp.mean(final_state ** 2)

# filter_value_and_grad makes loss_fn return a (loss value, gradients) pair.
loss_and_grads = loss_fn(model)
loss, grads = loss_and_grads
print(f"Loss: {loss:.6f}")

External links

Exercise

Diffrax 로 Lotka-Volterra ODE 를 풀고, final state 를 parameter 하나에 대해 미분한 gradient 를 구하라. 이렇게 'solver 를 통한 미분' 이 가능하다는 점이 JAX 가 ML 밖에서도 특별한 이유다.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요? 문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인 · 댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.