gradient 자체를 또 미분하고 싶을 때 — JAX 는 자유롭게 합성. grad(grad(f)) 가 그냥 작동.
import jax
import jax.numpy as jnp
def f(x):
return x ** 4
# 1차 미분: 4x³
print(jax.grad(f)(2.0)) # 32.0
# 2차 미분: 12x²
print(jax.grad(jax.grad(f))(2.0)) # 48.0
# 3차 미분: 24x
print(jax.grad(jax.grad(jax.grad(f)))(2.0)) # 48.0
실용 예 — Newton's method (convergence 빠름, hessian 사용):
def f(x):
return x ** 3 - 5 * x + 2
f_prime = jax.grad(f)
f_double_prime = jax.grad(f_prime)
x = 0.5
for _ in range(10):
x = x - f_prime(x) / f_double_prime(x)
print(f"근사근: {x}")
Jacobian — vector input/output 의 미분
scalar 함수에 grad. vector 함수엔 Jacobian:
def g(x): # R^3 → R^2
return jnp.array([x[0] * x[1], x[1] ** 2 + x[2]])
x = jnp.array([1.0, 2.0, 3.0])
# Jacobian: 2x3 행렬, J[i,j] = ∂g_i / ∂x_j
J = jax.jacrev(g)(x)
print(J)
# [[2., 1., 0.], ∂(x0*x1)/∂x = [x1, x0, 0]
# [0., 4., 1.]] ∂(x1²+x2)/∂x = [0, 2*x1, 1]
jacrev vs jacfwd — reverse vs forward mode.
- jacrev: 입력 차원 ≫ 출력 차원일 때 효율적. (예: 학습에서 params 1M, loss 1 개)
- jacfwd: 출력 차원 ≫ 입력 차원일 때 효율적. (예: input 3 차원에서 output 100 차원)
scalar loss 의 gradient 는 — 사실 jacrev(loss) 와 같음. jax.grad 는 — scalar output 일 때 jacrev 를 호출하면서 squeeze.
Hessian — 2 차 미분 행렬
def f(x):
return x[0] ** 2 + x[1] ** 2 + x[0] * x[1]
x = jnp.array([1.0, 2.0])
# Hessian = grad of grad, 또는 jacobian of grad
H = jax.hessian(f)(x)
print(H)
# [[2., 1.],
# [1., 2.]]
# 동등 표현
H_alt = jax.jacrev(jax.grad(f))(x)
H_alt2 = jax.jacfwd(jax.grad(f))(x)
🧮 forward vs reverse mode 의 차이
autodiff 는 두 modal: forward (input 측에서 derivatives 누적), reverse (output 측에서 backward). 모든 ML 학습은 reverse — N 개의 input (params) 에 대한 1 개의 output (loss). reverse mode 가 N 번의 forward 와 같은 cost. forward mode 는 1 번의 backward 와 같은 cost. M 개 output ≫ N 개 input 인 경우 jacfwd 가 빨라. 알면 — 큰 모델의 inverse 문제, sensitivity analysis 에서 유용.
Hessian 은 N x N 이라 큰 모델에선 직접 계산 안 함. Hessian-vector product (jax.jvp, jax.vjp) 로 효율 계산. second-order optimization (LBFGS, K-FAC, natural gradient) 에 등장.