C.W.K.
Stream
Lesson 03 of 05 · published

Higher-Order Gradients 와 Jacobians

~8 min · grad, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

gradient 자체를 또 미분하고 싶을 때 — JAX 는 자유롭게 합성. grad(grad(f)) 가 그냥 작동.

import jax
import jax.numpy as jnp

def f(x):
    return x ** 4

# 1차 미분: 4x³
print(jax.grad(f)(2.0))  # 32.0

# 2차 미분: 12x²
print(jax.grad(jax.grad(f))(2.0))  # 48.0

# 3차 미분: 24x
print(jax.grad(jax.grad(jax.grad(f)))(2.0))  # 48.0

실용 예 — Newton's method (convergence 빠름, hessian 사용):

def f(x):
    return x ** 3 - 5 * x + 2

f_prime = jax.grad(f)
f_double_prime = jax.grad(f_prime)

x = 0.5
for _ in range(10):
    x = x - f_prime(x) / f_double_prime(x)
print(f"근사근: {x}")

Jacobian — vector input/output 의 미분

scalar 함수에 grad. vector 함수엔 Jacobian:

def g(x):  # R^3 → R^2
    return jnp.array([x[0] * x[1], x[1] ** 2 + x[2]])

x = jnp.array([1.0, 2.0, 3.0])

# Jacobian: 2x3 행렬, J[i,j] = ∂g_i / ∂x_j
J = jax.jacrev(g)(x)
print(J)
# [[2., 1., 0.],   ∂(x0*x1)/∂x = [x1, x0, 0]
#  [0., 4., 1.]]   ∂(x1²+x2)/∂x = [0, 2*x1, 1]

jacrev vs jacfwd — reverse vs forward mode.

  • jacrev: 입력 차원 ≫ 출력 차원일 때 효율적. (예: 학습에서 params 1M, loss 1 개)
  • jacfwd: 출력 차원 ≫ 입력 차원일 때 효율적. (예: input 3 차원에서 output 100 차원)

scalar loss 의 gradient 는 — 사실 jacrev(loss) 와 같음. jax.grad 는 — scalar output 일 때 jacrev 를 호출하면서 squeeze.

Hessian — 2 차 미분 행렬

def f(x):
    return x[0] ** 2 + x[1] ** 2 + x[0] * x[1]

x = jnp.array([1.0, 2.0])

# Hessian = grad of grad, 또는 jacobian of grad
H = jax.hessian(f)(x)
print(H)
# [[2., 1.],
#  [1., 2.]]

# 동등 표현
H_alt = jax.jacrev(jax.grad(f))(x)
H_alt2 = jax.jacfwd(jax.grad(f))(x)

🧮 forward vs reverse mode 의 차이

autodiff 는 두 modal: forward (input 측에서 derivatives 누적), reverse (output 측에서 backward). 모든 ML 학습은 reverse — N 개의 input (params) 에 대한 1 개의 output (loss). reverse mode 가 N 번의 forward 와 같은 cost. forward mode 는 1 번의 backward 와 같은 cost. M 개 output ≫ N 개 input 인 경우 jacfwd 가 빨라. 알면 — 큰 모델의 inverse 문제, sensitivity analysis 에서 유용.

Hessian 은 N x N 이라 큰 모델에선 직접 계산 안 함. Hessian-vector product (jax.jvp, jax.vjp) 로 효율 계산. second-order optimization (LBFGS, K-FAC, natural gradient) 에 등장.

Code

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

df = jax.grad(f)         # cos(x)
d2f = jax.grad(df)       # -sin(x)
d3f = jax.grad(d2f)      # -cos(x)
d4f = jax.grad(d3f)      # sin(x)

x = jnp.array(jnp.pi / 4)
print(f"f(x)   = {f(x):.4f}")     # 0.7071 (sin)
print(f"f'(x)  = {df(x):.4f}")    # 0.7071 (cos)
print(f"f''(x) = {d2f(x):.4f}")   # -0.7071 (-sin)
print(f"f'''(x)= {d3f(x):.4f}")   # -0.7071 (-cos)
import jax
import jax.numpy as jnp

def f(x):
    return x[0] ** 2 * x[1] + x[1] ** 3

# Hessian: matrix of second partial derivatives
hessian_fn = jax.hessian(f)
x = jnp.array([1.0, 2.0])
H = hessian_fn(x)
print(H)
# [[ 4.  2.]   d^2f/dx0^2 = 2*x1 = 4,  d^2f/dx0dx1 = 2*x0 = 2
#  [ 2. 12.]]  d^2f/dx1dx0 = 2*x0 = 2,  d^2f/dx1^2  = 6*x1 = 12
import jax
import jax.numpy as jnp

def vector_fn(x):
    """Maps R^3 -> R^2."""
    return jnp.array([x[0] * x[1], x[1] ** 2 + x[2]])

x = jnp.array([1.0, 2.0, 3.0])

# jacrev: Reverse-mode Jacobian (efficient when output dim < input dim)
J_rev = jax.jacrev(vector_fn)(x)
print(J_rev)
# [[2. 1. 0.]    dy0/dx0=x1, dy0/dx1=x0, dy0/dx2=0
#  [0. 4. 1.]]   dy1/dx0=0,  dy1/dx1=2*x1, dy1/dx2=1
print(J_rev.shape)  # (2, 3)

# jacfwd: Forward-mode Jacobian (efficient when input dim < output dim)
J_fwd = jax.jacfwd(vector_fn)(x)
print(jnp.allclose(J_rev, J_fwd))  # True — same result, different algorithm

External links

Exercise

f(x) = sum(x**3) 에 대해 Jacobian (jacrev), Hessian (grad of grad), 분석 답 (3x², 6x) 검증. jacfwd vs jacrev cost — input/output 차원에 맞춰 선택.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.