JAX 의 진짜 매력 — 변환들이 자유롭게 합성됨. jax.jit(jax.grad(f)), jax.vmap(jax.grad(f)), jax.jit(jax.vmap(jax.grad(f))) 다 OK.
1. jit + grad — 학습 step 의 정석
@jax.jit # value_and_grad 결과를 compile
def train_step(params, x, y):
def loss_fn(p):
return jnp.mean((jnp.dot(x, p) - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
new_params = params - 0.01 * grads
return new_params, loss
# 첫 호출 — compile (느림)
# 이후 호출 — cache (빠름)
for _ in range(1000):
params, loss = train_step(params, x, y)
2. vmap + grad — per-example gradient
표준 grad 는 batch 의 mean loss 의 gradient (즉 batch-averaged). per-example gradient — 각 example 의 gradient 를 따로 — 가 필요한 경우가 많음 (gradient noise 분석, differential privacy, influence function).
def per_example_loss(params, x, y):
'''단일 example 의 loss'''
pred = jnp.dot(x, params)
return (pred - y) ** 2
# vmap 으로 batch 처리 — 단, gradient 는 per-example
per_example_grad_fn = jax.vmap(
jax.grad(per_example_loss),
in_axes=(None, 0, 0), # params 는 broadcast, x/y 는 batch
)
# x: (B, D), y: (B,), params: (D,)
per_grads = per_example_grad_fn(params, x, y) # shape: (B, D)
각 example 마다 gradient vector 가 따로 — 그런데 표준 batch grad 보다 별로 안 느림. 단순한 Python for-loop 보다 100x 빠름.
3. jit + vmap + grad — 합성의 정점
@jax.jit
def per_example_grads(params, batch_x, batch_y):
return jax.vmap(
jax.grad(per_example_loss),
in_axes=(None, 0, 0),
)(params, batch_x, batch_y)
g_each = per_example_grads(params, x, y) # (B, D)
g_mean = g_each.mean(0) # batch-averaged
g_var = g_each.var(0) # gradient variance
이런 패턴이 — PyTorch 에선 trick 이나 hook 으로만 가능했던 일. JAX 에선 한 줄.
4. grad + grad — 메타 학습
def inner_loss(params, x, y):
return jnp.mean((jnp.dot(x, params) - y) ** 2)
def outer_objective(initial_params, task_data):
'''첫 번째 inner step 후의 loss'''
x_train, y_train, x_val, y_val = task_data
inner_grad = jax.grad(inner_loss)(initial_params, x_train, y_train)
adapted = initial_params - 0.1 * inner_grad
return inner_loss(adapted, x_val, y_val)
# meta-gradient: outer obj 의 initial_params 에 대한 gradient
# (inner grad 를 통해 흐르는 gradient — 두 번 미분)
meta_grad = jax.grad(outer_objective)(meta_params, task_data)
MAML, Reptile 같은 메타 학습 알고리즘이 — JAX 에서는 30 줄 이하로.
🎼 합성이 곧 표현력
JAX 는 의도적으로 작은 primitive (jit, grad, vmap, pmap) 만 제공. 그러나 합성하면 — PyTorch 에서 별도 라이브러리 / hooks / functorch 로 가능했던 거의 모든 패턴을 표현할 수 있어. 외워야 할 API 가 적은데 가능한 일이 더 많아. 이게 functional 디자인의 약속이야.
합성의 순서가 중요할 수 있어. jit(vmap(grad(f))) vs vmap(jit(grad(f))) — 결과는 같지만 compile/cache 동작이 다름. 보통 — jit 을 가장 바깥에. jit(vmap(grad(f))) 가 표준 패턴.