C.W.K.
Stream
Lesson 04 of 05 · published

grad 를 jit 와 vmap 와 합성

~9 min · grad, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 진짜 매력 — 변환들이 자유롭게 합성됨. jax.jit(jax.grad(f)), jax.vmap(jax.grad(f)), jax.jit(jax.vmap(jax.grad(f))) 다 OK.

1. jit + grad — 학습 step 의 정석

@jax.jit  # value_and_grad 결과를 compile
def train_step(params, x, y):
    def loss_fn(p):
        return jnp.mean((jnp.dot(x, p) - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = params - 0.01 * grads
    return new_params, loss

# 첫 호출 — compile (느림)
# 이후 호출 — cache (빠름)
for _ in range(1000):
    params, loss = train_step(params, x, y)

2. vmap + grad — per-example gradient

표준 grad 는 batch 의 mean loss 의 gradient (즉 batch-averaged). per-example gradient — 각 example 의 gradient 를 따로 — 가 필요한 경우가 많음 (gradient noise 분석, differential privacy, influence function).

def per_example_loss(params, x, y):
    '''단일 example 의 loss'''
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# vmap 으로 batch 처리 — 단, gradient 는 per-example
per_example_grad_fn = jax.vmap(
    jax.grad(per_example_loss),
    in_axes=(None, 0, 0),  # params 는 broadcast, x/y 는 batch
)

# x: (B, D), y: (B,), params: (D,)
per_grads = per_example_grad_fn(params, x, y)  # shape: (B, D)

각 example 마다 gradient vector 가 따로 — 그런데 표준 batch grad 보다 별로 안 느림. 단순한 Python for-loop 보다 100x 빠름.

3. jit + vmap + grad — 합성의 정점

@jax.jit
def per_example_grads(params, batch_x, batch_y):
    return jax.vmap(
        jax.grad(per_example_loss),
        in_axes=(None, 0, 0),
    )(params, batch_x, batch_y)

g_each = per_example_grads(params, x, y)  # (B, D)
g_mean = g_each.mean(0)                    # batch-averaged
g_var = g_each.var(0)                      # gradient variance

이런 패턴이 — PyTorch 에선 trick 이나 hook 으로만 가능했던 일. JAX 에선 한 줄.

4. grad + grad — 메타 학습

def inner_loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

def outer_objective(initial_params, task_data):
    '''첫 번째 inner step 후의 loss'''
    x_train, y_train, x_val, y_val = task_data
    inner_grad = jax.grad(inner_loss)(initial_params, x_train, y_train)
    adapted = initial_params - 0.1 * inner_grad
    return inner_loss(adapted, x_val, y_val)

# meta-gradient: outer obj 의 initial_params 에 대한 gradient
# (inner grad 를 통해 흐르는 gradient — 두 번 미분)
meta_grad = jax.grad(outer_objective)(meta_params, task_data)

MAML, Reptile 같은 메타 학습 알고리즘이 — JAX 에서는 30 줄 이하로.

🎼 합성이 곧 표현력

JAX 는 의도적으로 작은 primitive (jit, grad, vmap, pmap) 만 제공. 그러나 합성하면 — PyTorch 에서 별도 라이브러리 / hooks / functorch 로 가능했던 거의 모든 패턴을 표현할 수 있어. 외워야 할 API 가 적은데 가능한 일이 더 많아. 이게 functional 디자인의 약속이야.

합성의 순서가 중요할 수 있어. jit(vmap(grad(f))) vs vmap(jit(grad(f))) — 결과는 같지만 compile/cache 동작이 다름. 보통 — jit 을 가장 바깥에. jit(vmap(grad(f))) 가 표준 패턴.

Code

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Compose: differentiate, then compile
fast_grad = jax.jit(jax.grad(loss_fn))

params = jnp.zeros(3)
x = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (100,))

# Fast compiled gradient computation
grads = fast_grad(params, x, y)
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    """Loss for ONE example."""
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# grad: gradient for one example
grad_single = jax.grad(single_loss)

# vmap(grad): gradient for each example independently
per_example_grad = jax.vmap(grad_single, in_axes=(None, 0, 0))

# jit it for speed
fast_per_example_grad = jax.jit(per_example_grad)

params = jnp.array([1.0, 2.0, 3.0])
X = jax.random.normal(jax.random.PRNGKey(0), (64, 3))  # 64 examples
y = jax.random.normal(jax.random.PRNGKey(1), (64,))

grads = fast_per_example_grad(params, X, y)
print(grads.shape)  # (64, 3) — one gradient per example!
import jax
import jax.numpy as jnp

def loss_with_target_network(params, target_params, x):
    """Common in RL: target network should not receive gradients."""
    pred = jnp.dot(x, params)
    # Stop gradient on target — treat as constant
    target = jax.lax.stop_gradient(jnp.dot(x, target_params))
    return jnp.mean((pred - target) ** 2)

# Gradient only flows to params, not target_params
grad_fn = jax.grad(loss_with_target_network)
grads = grad_fn(
    jnp.array([1.0, 2.0]),      # params — gets gradients
    jnp.array([0.5, 1.5]),      # target_params — no gradients
    jnp.array([[1.0, 0.5]])     # x
)
print(grads)  # Gradient with respect to params only

External links

Exercise

loss_fn(params, x, y). per-example gradient 를 jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)) 로. jit 추가. batch=512 에서 측정. 한 example 마다 grad 호출하는 Python loop 와 비교. speedup 이 — 합성이 중요한 이유.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.