C.W.K.
Stream
Lesson 05 of 05 · published

Meta-Learning: Grad-of-Grad 패턴

~9 min · advanced, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

가장 멀리 갈 수 있는 합성 — gradient 를 또 미분. meta-learning, hyperparameter 최적화, model-agnostic 알고리즘에서 등장.

가장 단순 — 2 차 미분

import jax
import jax.numpy as jnp

def f(x):
    return x ** 4

# 1차: 4x³
print(jax.grad(f)(2.0))   # 32

# 2차: 12x²
print(jax.grad(jax.grad(f))(2.0))   # 48

# 3차: 24x
print(jax.grad(jax.grad(jax.grad(f)))(2.0))   # 48

합성이 자유. 각 grad 가 새 함수를 돌려주니까.

Hessian-vector product

큰 model 의 Hessian 직접 계산은 — params 가 N 개면 N² 메모리. 대신 — Hessian 과 vector 의 곱 (HVP) 만 계산:

def hvp(f, x, v):
    '''f 의 Hessian 과 vector v 의 곱'''
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

# 사용
def loss(x):
    return jnp.sum(x ** 4)

x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., 0.])
print(hvp(loss, x, v))   # H @ v

Newton's method, Conjugate Gradient, K-FAC 같은 second-order optimizer 의 building block.

Meta-learning: MAML

Model-Agnostic Meta-Learning — "여러 task 에서 빠르게 적응 가능한 initial parameter 찾기".

def task_loss(params, task_data):
    x, y = task_data
    pred = model_apply(params, x)
    return jnp.mean((pred - y) ** 2)

def maml_inner_step(params, task, lr=0.1):
    '''단일 task 에 대해 1 step 업데이트'''
    grads = jax.grad(task_loss)(params, task)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

def maml_outer_loss(meta_params, tasks):
    '''meta-learning objective'''
    total = 0.0
    for task in tasks:
        # support set 에서 1 step 학습
        adapted = maml_inner_step(meta_params, task["support"])
        # query set 에서 평가 — 이게 meta-loss
        total += task_loss(adapted, task["query"])
    return total / len(tasks)

# meta-gradient — outer loss 의 meta_params 에 대한 gradient
# 안에 grad (inner) 가 있고, 그 위로 또 grad (outer) — 2 차 미분
@jax.jit
def meta_step(meta_params, tasks, meta_lr=0.001):
    grads = jax.grad(maml_outer_loss)(meta_params, tasks)
    return jax.tree.map(lambda p, g: p - meta_lr * g, meta_params, grads)

JAX 가 — 이 grad-of-grad 패턴을 — 추가 코드 한 줄도 없이 자동 처리. 다른 framework 에서는 — explicit second-order autodiff 작성, training loop 변경, gradient hooks. JAX 는 — jax.grad(jax.grad(...)).

Implicit differentiation — 더 깊은 패턴

고정점 (fixed point) 의 gradient — 별도 trick:

# f(x*, theta) = 0 의 고정점 x* (theta 의 함수)
# implicit function theorem: dx*/dtheta = -(df/dx)^-1 (df/dtheta)

def implicit_solver(f, theta, x0):
    '''f(x, theta) = 0 의 fixed point — 자동 미분 가능'''
    # ... fixed-point 반복으로 x* 찾기
    return x_star

# JAX 의 jaxopt 라이브러리가 이 패턴을 깔끔히 wrap

RL 의 implicit reward modeling, iterative optimization, hyperparameter tuning 등 — 모두 같은 핵심 trick.

🌌 grad-of-grad 의 의의

JAX 가 PyTorch 보다 메타 학습 / 과학 계산 분야에서 더 사랑받는 이유 — 이런 자유로운 합성이 실제로 작동. 박사 논문이었던 알고리즘이 — 학생이 1 일 안에 실험 가능. jax.grad(jax.grad(jax.grad(f))) 가 그냥 작동하는 framework — 다른 데서는 hack 또는 별도 라이브러리.

주의 — 깊은 grad-of-grad 는 메모리 / 시간 비용이 커. 4 차 미분쯤 가면 일반적으론 불필요. 2 차 (Hessian, MAML) 까지가 흔한 sweet spot.

Code

import jax
import jax.numpy as jnp

# Second-order derivatives are trivial
def f(x):
    return jnp.sin(x) * x ** 2

# First derivative
df = jax.grad(f)
print(df(1.0))  # cos(1)*1 + sin(1)*2 ≈ 2.22

# Second derivative (Hessian for scalar functions)
d2f = jax.grad(jax.grad(f))
print(d2f(1.0))  # ≈ -0.18

# Hessian for vector functions
def g(x):
    return jnp.sum(x ** 3)

hessian = jax.hessian(g)
print(hessian(jnp.array([1.0, 2.0, 3.0])))
# [[6., 0., 0.],
#  [0., 12., 0.],
#  [0., 0., 18.]]
def maml_loss(meta_params, tasks, inner_lr=0.01, inner_steps=5):
    """MAML outer loss: how well do inner-loop-adapted params perform?"""
    total_loss = 0.0

    for task_train, task_test in tasks:
        # Inner loop: adapt to task using gradient descent
        params = meta_params
        for _ in range(inner_steps):
            train_loss = compute_loss(params, *task_train)
            grads = jax.grad(compute_loss)(params, *task_train)
            params = jax.tree.map(
                lambda p, g: p - inner_lr * g, params, grads)

        # Outer loss: evaluate adapted params on test data
        test_loss = compute_loss(params, *task_test)
        total_loss += test_loss

    return total_loss / len(tasks)

# The magic: differentiate through the inner loop!
meta_grads = jax.grad(maml_loss)(meta_params, tasks)
# This computes second-order gradients (gradient through gradient descent)
# In PyTorch, this requires create_graph=True and careful management.
# In JAX, it just works — grad(grad(...)) composes naturally.
# Practical example: memory-efficient Transformer with remat + scan
def create_efficient_transformer(d_model, num_heads, d_ff, num_layers, rngs):
    """Create a transformer that uses scan + remat for efficiency."""
    # Initialize one block's parameters
    def init_block(rngs):
        return TransformerBlock(d_model, num_heads, d_ff, rngs)

    blocks = [init_block(rngs) for _ in range(num_layers)]

    def forward(x):
        for block in blocks:
            # Remat: recompute activations in backward pass (saves memory)
            x = jax.checkpoint(lambda b, x: b(x), block, x)
        return x

    return forward, blocks

External links

Exercise

MAML 의 inner-loop update 를 30 줄 안에 구현: 1 inner SGD step 통한 meta-grad. 2-task synthetic problem 에서 검증. grad-of-grad 패턴이 — quest 전체의 conceptual peak — 거저 주어진 게 아니라 얻은 거.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.