가장 자주 쓰는 vmap+grad 패턴 — per-example gradient. 표준 학습은 batch-mean gradient 를 쓰지만, 가끔 example 마다 gradient 를 따로 보고 싶을 때가 있어 (gradient noise scale, differential privacy, influence functions).
import jax
import jax.numpy as jnp
# 단일 example 의 loss
def loss_one(params, x, y):
'''x: (D,), y: (), params: (D,)'''
pred = jnp.dot(x, params)
return (pred - y) ** 2
# grad — 단일 example 의 gradient
g_one = jax.grad(loss_one)
grad_for_one_example = g_one(params, x_single, y_single) # shape: (D,)
# vmap — batch 의 모든 example 의 gradient
g_each = jax.vmap(g_one, in_axes=(None, 0, 0))
per_grads = g_each(params, batch_x, batch_y) # shape: (B, D)
# 비교: 표준 batch loss 의 gradient
def batch_loss(params, X, Y):
pred = X @ params
return jnp.mean((pred - Y) ** 2)
batch_grad = jax.grad(batch_loss)(params, batch_x, batch_y) # shape: (D,)
# 검증: per_grads.mean(0) ≈ batch_grad
print(jnp.allclose(per_grads.mean(0), batch_grad)) # True (거의)
이게 PyTorch 에서 hook 이나 functorch 없이는 어려웠던 일. JAX 에서는 한 줄.
활용 1: Gradient 분포 분석
@jax.jit
def per_example_grads(params, X, Y):
return jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)
g_each = per_example_grads(params, X_batch, Y_batch) # (B, D)
# gradient norm 분포
norms = jnp.linalg.norm(g_each, axis=1) # (B,)
print(f"평균 norm: {norms.mean():.4f}")
print(f"최대 norm: {norms.max():.4f} (outlier 가능성)")
print(f"norm std: {norms.std():.4f}")
# outlier example 식별
outlier_idx = jnp.argmax(norms)
print(f"가장 큰 gradient 의 example: {outlier_idx}, norm={norms[outlier_idx]}")
실전 — 학습이 불안정할 때 outlier example 을 찾는 게 첫 디버깅 단계.
활용 2: Differential Privacy (DP-SGD)
def dp_sgd_step(params, X, Y, lr, clip_norm, noise_scale, key):
# per-example gradients
g_each = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)
# 각 example 의 norm clipping
g_each = jax.vmap(
lambda g: g * jnp.minimum(1, clip_norm / jnp.linalg.norm(g))
)(g_each)
# 평균 + noise
g_clipped_mean = g_each.mean(0)
noise = jax.random.normal(key, g_clipped_mean.shape) * noise_scale
return params - lr * (g_clipped_mean + noise)
DP-SGD — 학술 알고리즘이 30 줄 이하. 다른 framework 에선 별도 라이브러리.
활용 3: Influence Function
"이 한 example 을 빼면 model 이 어떻게 바뀔까?" 의 근사. per-example gradient 가 building block.
# 매우 간단화된 influence
def influence_score(params, X, Y, x_test, y_test):
g_each = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)
g_test = jax.grad(loss_one)(params, x_test, y_test)
return -jnp.dot(g_each, g_test) # train example 마다 한 score
🔬 vmap + grad 가 푸는 문제 카테고리
"각 example 별 무언가" 가 필요한 모든 작업 — per-example gradient, per-example Hessian (vmap of grad of grad), per-example feature attribution. JAX 의 합성 덕에 한 줄로 표현 가능. 옛날엔 박사 학위 논문이었던 걸 — 이젠 lab notebook 에서.