대부분의 ML framework 는 객체 중심이야. model object 만들고, 거기에 데이터 흘려보내고, 메서드 호출. JAX 는 정반대로 가. 핵심 아이디어가 composable function transformations — 함수를 인자로 받아서 새로운 변환된 함수를 돌려주는 higher-order 함수.
4 대 변환 — JAX 의 "four pillars":
jax.jit— Compile: 함수를 trace 해서 XLA compiler 로 최적화된 머신 코드로. 첫 호출은 compile 때문에 느리지만, 다음 호출부터는 극적으로 빠름.jax.grad— Differentiate: scalar 반환 함수를 받아서 gradient 계산하는 새 함수를 돌려줌.jax.vmap— Vectorize: 한 example 에 작동하는 함수를 batch 에 작동하도록 변환.jax.pmap— Parallelize: 여러 device (GPU/TPU) 에 함수를 동시에 돌림.
핵심 단어가 composable 이야. 자유롭게 stack 가능:
import jax
import jax.numpy as jnp
def loss_fn(params, x, y):
predictions = jnp.dot(x, params)
return jnp.mean((predictions - y) ** 2)
# 합성: 미분 → compile
fast_grad = jax.jit(jax.grad(loss_fn))
# 또는: per-example gradient → compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))
💡 왜 이게 중요한가
각 변환은 자체로 유용해. 근데 진짜 마법은 합성에서 나와. jit(vmap(grad(f))) — 이걸 PyTorch 에서 하려면 코드 한참 다시 짜야 돼. JAX 에선 한 줄. 변환을 1 등 시민으로 만든 게 — 이게 differential programming 의 패러다임 전환이야.