C.W.K.
Stream
Lesson 01 of 05 · published

JAX 가 과학 계산에서 빛나는 이유

~8 min · scientific, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 ML framework 측면 — 이미 강력. 그런데 진짜 차별점 — ML 아닌 영역에서. JAX 가 차별화되는 이유: 수치 코드를 자유롭게 미분 + 가속.

전통적 과학 계산 stack:

  • NumPy / SciPy — single-thread CPU. 미분 없음.
  • MATLAB — toolbox 풍부, 그러나 closed-source, gradient 따로.
  • Fortran / C++ — 빠르지만 hand-written gradient.
  • Numba — JIT, 그러나 미분 없음.

JAX 가 합치는 두 가지:

  1. 임의 numerical 코드의 자동 미분: ODE 풀이, 시뮬레이션, sampling — 모든 코드 미분 가능.
  2. 가속기 활용: GPU/TPU 위에서 동일한 Python 코드.

이 둘이 합쳐지면 — 새 카테고리의 알고리즘이 가능해져.

전형적 use case

1. Differentiable physics — 시뮬레이터를 통해 gradient 흘려, 물리 시스템 학습

def simulate_pendulum(initial_angle, length, dt, n_steps):
    '''단순한 진자 시뮬레이션'''
    theta, omega = initial_angle, 0.0
    for _ in range(n_steps):
        omega += -9.81 / length * jnp.sin(theta) * dt
        theta += omega * dt
    return theta

# 진자가 특정 각도에 도달하도록 길이를 학습
target_angle = 0.5
def loss(length):
    final = simulate_pendulum(0.1, length, 0.01, 100)
    return (final - target_angle) ** 2

# 자동 미분 — 시뮬레이션 통해 gradient 흐름
optimal_length = optimize_with_grad(loss)

2. Differential equations — Diffrax 가 미분 가능한 ODE/SDE solver

from diffrax import diffeqsolve, Tsit5, ODETerm

def lorenz(t, y, args):
    x, y_, z = y
    return jnp.array([
        10 * (y_ - x),
        x * (28 - z) - y_,
        x * y_ - 8/3 * z,
    ])

solution = diffeqsolve(
    ODETerm(lorenz),
    Tsit5(),
    t0=0.0, t1=10.0, dt0=0.01,
    y0=jnp.array([1., 1., 1.]),
)

3. Bayesian inference — NumPyro 가 JAX 위의 Pyro

import numpyro
import numpyro.distributions as dist

def model(data):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# NUTS sampler — JAX-jit 으로 자동 가속
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), num_samples=1000)
mcmc.run(rng_key, data)

4. Optimization — JAXopt, Lineax 같은 라이브러리

import jaxopt

# 비선형 least squares — 자동 미분 + GPU
solver = jaxopt.LevenbergMarquardt(residual_fun=residuals)
result = solver.run(init_params, data=observations)

5. Probabilistic programming — TensorFlow Probability JAX backend

6. Reinforcement learning — Brax (differentiable physics), MJX (MuJoCo)

7. Quantum computing — qujax, PennyLane JAX backend

8. Computational chemistry — JAX-MD (molecular dynamics)

🔬 미분 가능한 시뮬레이션의 의의

전통적 과학 — observation 으로부터 model parameter 추정. JAX-시대 — 시뮬레이션 자체가 미분 가능. 그래서 — gradient descent 로 parameter 학습 + 실험 설계 + sensitivity analysis 가 한 framework. "differentiable everything" 의 패러다임 — 박사 논문 한 권의 가치.

이 quest 의 다음 lesson 들 — Diffrax (ODE), NumPyro (Bayesian), Brax/JAX-MD (시뮬레이션). 각각 — JAX 에코시스템의 한 영역.

Code

import jax
import jax.numpy as jnp

# Example: differentiable physics simulation
def simulate_spring(k, m, x0, v0, dt, num_steps):
    """Simulate a damped spring system: m*x'' + 0.1*x' + k*x = 0"""
    def step(state, _):
        x, v = state
        a = (-k * x - 0.1 * v) / m  # spring force + damping
        v_new = v + a * dt
        x_new = x + v_new * dt
        return (x_new, v_new), x_new

    init_state = (x0, v0)
    _, trajectory = jax.lax.scan(step, init_state, None, length=num_steps)
    return trajectory

# Simulate
traj = simulate_spring(k=2.0, m=1.0, x0=1.0, v0=0.0, dt=0.01, num_steps=1000)

# Gradient: how does the final position change with spring constant?
@jax.jit
def final_position(k):
    return simulate_spring(k, 1.0, 1.0, 0.0, 0.01, 1000)[-1]

dk = jax.grad(final_position)(2.0)
print(f"d(final_pos)/dk = {dk:.6f}")

# Vectorize: simulate 100 different spring constants at once
ks = jnp.linspace(0.5, 5.0, 100)
all_trajectories = jax.vmap(lambda k: simulate_spring(k, 1.0, 1.0, 0.0, 0.01, 1000))(ks)
print(f"Batch trajectories shape: {all_trajectories.shape}")  # (100, 1000)

External links

Exercise

'differentiable + accelerated' 가 classical scientific Python (NumPy/SciPy/Numba) 을 이기는 5 분야 list. 자기가 쓸 거 1 개 + 3 줄 정당화. list 보다 framing 이 더 중요.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.