C.W.K.
Stream
Lesson 03 of 04 · published

vmap + grad: per-example gradient

~9 min · vmap, jax, tutorial


The most common vmap + grad pattern is the per-example gradient. Standard training uses the gradient of the batch-mean loss, but sometimes you want each example's gradient separately: for estimating the gradient noise scale, for differential privacy, or for influence functions.
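In symbols (notation ours, not the lesson's): with per-example loss ℓ, parameters θ and batch size B, the per-example gradient g_i and the batch-mean gradient are related by linearity of the gradient. For the squared-error loss used below, g_i also has a closed form.

```latex
g_i = \nabla_\theta\, \ell(\theta; x_i, y_i), \qquad
\nabla_\theta\, \frac{1}{B}\sum_{i=1}^{B} \ell(\theta; x_i, y_i) = \frac{1}{B}\sum_{i=1}^{B} g_i,
\qquad
\ell = (x_i^\top \theta - y_i)^2 \;\Rightarrow\; g_i = 2\,(x_i^\top \theta - y_i)\, x_i
```

The per-example version keeps all B vectors instead of collapsing them into their mean.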

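import jax
import jax.numpy as jnp

# Loss for a single example
def loss_one(params, x, y):
    '''x: (D,), y: (), params: (D,)'''
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# Toy data so the snippet runs end to end (B examples, D features)
B, D = 8, 3
kx, ky = jax.random.split(jax.random.PRNGKey(0))
params = jnp.array([1.0, 2.0, 3.0])
batch_x = jax.random.normal(kx, (B, D))
batch_y = jax.random.normal(ky, (B,))
x_single, y_single = batch_x[0], batch_y[0]

# grad: gradient for a single example
g_one = jax.grad(loss_one)
grad_for_one_example = g_one(params, x_single, y_single)  # shape: (D,)

# vmap: gradient for every example in the batch
# in_axes=(None, 0, 0): share params, map over the leading axis of x and y
g_each = jax.vmap(g_one, in_axes=(None, 0, 0))
per_grads = g_each(params, batch_x, batch_y)  # shape: (B, D)

# For comparison: gradient of the standard (mean) batch loss
def batch_loss(params, X, Y):
    pred = X @ params
    return jnp.mean((pred - Y) ** 2)

batch_grad = jax.grad(batch_loss)(params, batch_x, batch_y)  # shape: (D,)

# Check: per_grads.mean(0) ≈ batch_grad (equal up to float32 rounding)
print(jnp.allclose(per_grads.mean(0), batch_grad, atol=1e-5))  # True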
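As a sanity check, for this linear squared-error loss the per-example gradient has the closed form g_i = 2(x_i·θ − y_i) x_i, so the vmap(grad) result can be compared against it directly. A minimal sketch with toy random data (shapes and seed are arbitrary choices of ours):

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (16, 4))
Y = jax.random.normal(ky, (16,))
params = jnp.ones(4)

# Autodiff: per-example gradients, shape (16, 4)
auto = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)

# Closed form: 2 * residual_i * x_i, broadcasting the residual over features
residual = X @ params - Y              # (16,)
manual = 2.0 * residual[:, None] * X   # (16, 4)

print(auto.shape, jnp.allclose(auto, manual, atol=1e-5))
```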

In PyTorch this used to require backward hooks or functorch (since merged into torch.func); in JAX it is a single line.

Use case 1: Analyzing the gradient distribution

@jax.jit
def per_example_grads(params, X, Y):
    return jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)

g_each = per_example_grads(params, batch_x, batch_y)  # (B, D)

# Distribution of per-example gradient norms
norms = jnp.linalg.norm(g_each, axis=1)  # (B,)
print(f"mean norm: {norms.mean():.4f}")
print(f"max norm:  {norms.max():.4f} (possible outlier)")
print(f"norm std:  {norms.std():.4f}")

# Identify the example with the largest gradient
outlier_idx = int(jnp.argmax(norms))
print(f"largest-gradient example: {outlier_idx}, norm={norms[outlier_idx]:.4f}")

In practice: when training becomes unstable, looking for outlier examples by gradient norm is a good first debugging step.
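To see the diagnostic work, here is a toy sketch that corrupts one label on purpose and checks that the largest per-example gradient norm points at it. The setup (targets generated exactly by the model, label 5 shifted by 100) is ours, chosen so the answer is unambiguous:

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

per_example_grads = jax.jit(jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0)))

X = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
params = jnp.array([1.0, -2.0, 0.5])
Y = X @ params            # clean targets: residuals (and gradients) are ~0
Y = Y.at[5].add(100.0)    # corrupt a single label

norms = jnp.linalg.norm(per_example_grads(params, X, Y), axis=1)  # (32,)
outlier_idx = int(jnp.argmax(norms))
print(outlier_idx)  # 5
```

Only the corrupted example has a large residual, so its gradient norm dominates by several orders of magnitude.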

Use case 2: Differential privacy (DP-SGD)

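def dp_sgd_step(params, X, Y, lr, clip_norm, noise_scale, key):
    # per-example gradients: (B, D)
    g_each = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)

    # clip each example's gradient to L2 norm <= clip_norm
    # (the small epsilon avoids dividing by zero for an all-zero gradient)
    norms = jnp.linalg.norm(g_each, axis=1, keepdims=True)  # (B, 1)
    g_each = g_each * jnp.minimum(1.0, clip_norm / (norms + 1e-8))

    # average, then add Gaussian noise: std noise_scale * clip_norm on the sum,
    # i.e. noise_scale * clip_norm / B on the mean (noise_scale = noise multiplier)
    g_clipped_mean = g_each.mean(0)
    noise = jax.random.normal(key, g_clipped_mean.shape) * noise_scale * clip_norm / X.shape[0]
    return params - lr * (g_clipped_mean + noise)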

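The core of DP-SGD (clip, average, add noise) fits in about a dozen lines. Other frameworks have typically relied on a dedicated library for this (Opacus in PyTorch, for instance). This is only the gradient step: tracking the privacy budget (ε, δ) still needs a separate privacy accountant.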
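A standalone sketch of the step inside a short training loop, with a check that clipping really bounds every per-example norm. The helper name `clipped_per_example_grads`, the hyperparameters and the toy regression data are hypothetical choices for illustration. Note the key split on every iteration: reusing one key would add the identical noise vector at every step.

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

def clipped_per_example_grads(params, X, Y, clip_norm):
    """Hypothetical helper: per-example gradients rescaled to L2 norm <= clip_norm."""
    g = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)
    norms = jnp.linalg.norm(g, axis=1, keepdims=True)
    return g * jnp.minimum(1.0, clip_norm / (norms + 1e-8))

@jax.jit
def dp_sgd_step(params, X, Y, key, lr=0.1, clip_norm=1.0, noise_scale=1.0):
    g = clipped_per_example_grads(params, X, Y, clip_norm)
    # Gaussian noise with std noise_scale * clip_norm on the sum -> divide by B for the mean
    noise = jax.random.normal(key, params.shape) * noise_scale * clip_norm / X.shape[0]
    return params - lr * (g.mean(0) + noise)

kx, ky, key = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (64, 3))
Y = X @ jnp.array([1.0, -2.0, 0.5]) + 0.1 * jax.random.normal(ky, (64,))
params = jnp.zeros(3)

# The clipping bound holds for every example
g = clipped_per_example_grads(params, X, Y, clip_norm=1.0)
print(float(jnp.linalg.norm(g, axis=1).max()) <= 1.0 + 1e-5)

# Training loop: split the key so each step draws fresh noise
for _ in range(200):
    key, subkey = jax.random.split(key)
    params = dp_sgd_step(params, X, Y, subkey)
```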

Use case 3: Influence functions

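"How would the model change if this one training example were removed?" Influence functions approximate the answer without retraining, and per-example gradients are their building block.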

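# Heavily simplified influence: the inverse-Hessian term is dropped (H ≈ I),
# so the score reduces to a gradient dot product
def influence_score(params, X, Y, x_test, y_test):
    g_each = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)  # (B, D)
    g_test = jax.grad(loss_one)(params, x_test, y_test)                          # (D,)
    return -jnp.dot(g_each, g_test)  # one score per training example, shape (B,)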
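The classic influence-function approximation keeps the inverse Hessian: score_i = −g_testᵀ H⁻¹ g_i, where H is the Hessian of the mean training loss at (or near) its minimizer. A sketch for this small linear model, where the D×D Hessian is cheap to form explicitly. The damping term and the least-squares fit used to get near-optimal params are our additions; at scale, Hessian-vector products replace the explicit H.

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

def mean_loss(params, X, Y):
    return jnp.mean(jax.vmap(loss_one, in_axes=(None, 0, 0))(params, X, Y))

def influence_with_hessian(params, X, Y, x_test, y_test, damping=1e-3):
    g_each = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)  # (B, D)
    g_test = jax.grad(loss_one)(params, x_test, y_test)                          # (D,)
    H = jax.hessian(mean_loss)(params, X, Y)                                     # (D, D)
    H = H + damping * jnp.eye(H.shape[0])
    # Solve H v = g_test instead of forming H^{-1}; H is symmetric,
    # so g_i^T H^{-1} g_test == g_test^T H^{-1} g_i
    v = jnp.linalg.solve(H, g_test)
    return -(g_each @ v)  # (B,)

kx, ky, kt = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (50, 3))
Y = jax.random.normal(ky, (50,))
params = jnp.linalg.lstsq(X, Y)[0]   # minimizer of the mean squared error
x_test = jax.random.normal(kt, (3,))

scores = influence_with_hessian(params, X, Y, x_test, 0.0)
print(scores.shape)  # (50,)
```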

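🔬 The class of problems vmap + grad solves

Any task that needs "something per example" fits this pattern: per-example gradients, per-example Hessians (vmap over jax.hessian, rather than grad of grad, since jax.grad needs a scalar output), per-example feature attributions. Because JAX transforms compose, each is a single expression. Work that once required substantial custom engineering now fits in a lab notebook.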
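The per-example Hessian is one more composition. A sketch with the same linear squared-error loss, where each example's Hessian should equal 2 x_i x_iᵀ, which gives an exact value to check against (toy data are ours):

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

# jax.hessian differentiates w.r.t. the first argument (params) by default
per_example_hessian = jax.vmap(jax.hessian(loss_one), in_axes=(None, 0, 0))

X = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
Y = jnp.zeros(10)
params = jnp.ones(4)

H_each = per_example_hessian(params, X, Y)        # (10, 4, 4)
expected = 2.0 * jnp.einsum("bi,bj->bij", X, X)   # 2 x_i x_i^T for each example
print(H_each.shape, jnp.allclose(H_each, expected, atol=1e-5))
```

Memory grows as B·D², so for real networks this is usually replaced by per-example Hessian-vector products.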

Code

import jax
import jax.numpy as jnp

def loss_single(params, x, y):
    """Loss for a SINGLE example — keep it simple!"""
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# Regular gradient: gradient of average loss
def batch_loss(params, X, y):
    return jnp.mean(jax.vmap(loss_single, in_axes=(None, 0, 0))(params, X, y))

avg_grad = jax.grad(batch_loss)

# Per-example gradients: separate gradient for each example
per_example_grad = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))

# Test
params = jnp.array([1.0, 2.0, 3.0])
X = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (32,))

# Average gradient: shape (3,)
g_avg = avg_grad(params, X, y)
print(f"Average gradient shape: {g_avg.shape}")  # (3,)

# Per-example gradients: shape (32, 3)
g_per = per_example_grad(params, X, y)
print(f"Per-example gradient shape: {g_per.shape}")  # (32, 3)

# Verify: mean of per-example gradients ≈ average gradient
print(f"Match: {jnp.allclose(jnp.mean(g_per, axis=0), g_avg, atol=1e-5)}")

import jax
import jax.numpy as jnp

def loss_single(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

def dp_gradient(params, X, y, key, clip_norm=1.0, noise_scale=0.1):
    """Differentially private gradient (the gradient part of one DP-SGD step).

    noise_scale is the noise multiplier: Gaussian noise with std
    noise_scale * clip_norm is added to the sum of clipped gradients.
    `key` is required; pass a fresh key on every step.
    """
    # 1. Get per-example gradients
    per_ex_grads = jax.vmap(
        jax.grad(loss_single), in_axes=(None, 0, 0)
    )(params, X, y)

    # 2. Clip each gradient to the norm bound
    grad_norms = jnp.linalg.norm(per_ex_grads, axis=-1, keepdims=True)
    clip_factor = jnp.minimum(1.0, clip_norm / (grad_norms + 1e-8))
    clipped_grads = per_ex_grads * clip_factor

    # 3. Average, then add noise: std noise_scale * clip_norm on the sum,
    #    hence divided by the batch size for the mean
    avg_grad = jnp.mean(clipped_grads, axis=0)
    noise = noise_scale * clip_norm * jax.random.normal(key, avg_grad.shape)
    return avg_grad + noise / X.shape[0]

# Usage
key = jax.random.PRNGKey(42)
params = jnp.array([1.0, 2.0, 3.0])
X = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (64,))

dp_grad = dp_gradient(params, X, y, key=key)
print(f"DP gradient: {dp_grad}")


Exercise

Compute per-example gradients for a batch of 256 examples, verify that their mean matches the gradient of the mean loss, then plot a histogram of the gradient norms. This is a diagnostic used in actual research.
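One possible starting point for the exercise, as a sketch: synthetic data stand in for a real dataset, and jnp.histogram produces the bin counts, so the plot itself is left to whichever plotting library you use.

```python
import jax
import jax.numpy as jnp

def loss_one(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

def mean_loss(params, X, Y):
    return jnp.mean((X @ params - Y) ** 2)

kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (256, 5))
Y = jax.random.normal(ky, (256,))
params = jnp.zeros(5)

per_grads = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, X, Y)  # (256, 5)
mean_grad = jax.grad(mean_loss)(params, X, Y)                                  # (5,)
match = bool(jnp.allclose(per_grads.mean(0), mean_grad, atol=1e-5))
print("match:", match)

# Histogram of per-example gradient norms (bin counts and bin edges)
norms = jnp.linalg.norm(per_grads, axis=1)   # (256,)
counts, edges = jnp.histogram(norms, bins=20)
print(int(counts.sum()))
```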
