가장 멀리 갈 수 있는 합성 — gradient 를 또 미분. meta-learning, hyperparameter 최적화, model-agnostic 알고리즘에서 등장.
가장 단순 — 2 차 미분
import jax
import jax.numpy as jnp
def f(x):
return x ** 4
# 1차: 4x³
print(jax.grad(f)(2.0)) # 32
# 2차: 12x²
print(jax.grad(jax.grad(f))(2.0)) # 48
# 3차: 24x
print(jax.grad(jax.grad(jax.grad(f)))(2.0)) # 48
합성이 자유. 각 grad 가 새 함수를 돌려주니까.
Hessian-vector product
큰 model 의 Hessian 직접 계산은 — params 가 N 개면 N² 메모리. 대신 — Hessian 과 vector 의 곱 (HVP) 만 계산:
def hvp(f, x, v):
'''f 의 Hessian 과 vector v 의 곱'''
return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)
# 사용
def loss(x):
return jnp.sum(x ** 4)
x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., 0.])
print(hvp(loss, x, v)) # H @ v
Newton's method, Conjugate Gradient, K-FAC 같은 second-order optimizer 의 building block.
Meta-learning: MAML
Model-Agnostic Meta-Learning — "여러 task 에서 빠르게 적응 가능한 initial parameter 찾기".
def task_loss(params, task_data):
x, y = task_data
pred = model_apply(params, x)
return jnp.mean((pred - y) ** 2)
def maml_inner_step(params, task, lr=0.1):
'''단일 task 에 대해 1 step 업데이트'''
grads = jax.grad(task_loss)(params, task)
return jax.tree.map(lambda p, g: p - lr * g, params, grads)
def maml_outer_loss(meta_params, tasks):
'''meta-learning objective'''
total = 0.0
for task in tasks:
# support set 에서 1 step 학습
adapted = maml_inner_step(meta_params, task["support"])
# query set 에서 평가 — 이게 meta-loss
total += task_loss(adapted, task["query"])
return total / len(tasks)
# meta-gradient — outer loss 의 meta_params 에 대한 gradient
# 안에 grad (inner) 가 있고, 그 위로 또 grad (outer) — 2 차 미분
@jax.jit
def meta_step(meta_params, tasks, meta_lr=0.001):
grads = jax.grad(maml_outer_loss)(meta_params, tasks)
return jax.tree.map(lambda p, g: p - meta_lr * g, meta_params, grads)
JAX 가 — 이 grad-of-grad 패턴을 — 추가 코드 한 줄도 없이 자동 처리. 다른 framework 에서는 — explicit second-order autodiff 작성, training loop 변경, gradient hooks. JAX 는 — jax.grad(jax.grad(...)).
Implicit differentiation — 더 깊은 패턴
고정점 (fixed point) 의 gradient — 별도 trick:
# f(x*, theta) = 0 의 고정점 x* (theta 의 함수)
# implicit function theorem: dx*/dtheta = -(df/dx)^-1 (df/dtheta)
def implicit_solver(f, theta, x0):
'''f(x, theta) = 0 의 fixed point — 자동 미분 가능'''
# ... fixed-point 반복으로 x* 찾기
return x_star
# JAX 의 jaxopt 라이브러리가 이 패턴을 깔끔히 wrap
RL 의 implicit reward modeling, iterative optimization, hyperparameter tuning 등 — 모두 같은 핵심 trick.
🌌 grad-of-grad 의 의의
JAX 가 PyTorch 보다 메타 학습 / 과학 계산 분야에서 더 사랑받는 이유 — 이런 자유로운 합성이 실제로 작동. 박사 논문이었던 알고리즘이 — 학생이 1 일 안에 실험 가능. jax.grad(jax.grad(jax.grad(f))) 가 그냥 작동하는 framework — 다른 데서는 hack 또는 별도 라이브러리.
주의 — 깊은 grad-of-grad 는 메모리 / 시간 비용이 커. 4 차 미분쯤 가면 일반적으론 불필요. 2 차 (Hessian, MAML) 까지가 흔한 sweet spot.