jax.grad 는 higher-order function — 함수를 받아서 새 함수를 돌려줘. 새 함수는 같은 인자를 받지만, 결과는 원래 함수의 gradient.
import jax
import jax.numpy as jnp
def f(x):
return x ** 2 + 3 * x + 5
# grad(f) 는 새로운 함수
df = jax.grad(f)
print(f(2.0)) # 15.0
print(df(2.0)) # 7.0 (정답: 2x + 3 = 7)
print(df(0.0)) # 3.0
print(df(-1.0)) # 1.0
한 가지 규칙 — grad 가 받는 함수는 scalar 를 반환해야 해. 여러 출력을 가진 함수의 gradient 가 뭔지 — 수학적으로는 Jacobian — JAX 는 별도 함수 jax.jacrev, jax.jacfwd 로 처리.
def g(x):
return jnp.array([x ** 2, x ** 3])
# jax.grad(g)(2.0) # ❌ TypeError: grad requires scalar output
jacobian = jax.jacrev(g)
print(jacobian(2.0)) # [4., 12.] ← d/dx [x², x³]
여러 인자에 대한 gradient — default 로 첫 번째 인자에만:
def loss(params, x, y):
pred = jnp.dot(x, params)
return jnp.mean((pred - y) ** 2)
# 첫 번째 인자 (params) 만
g = jax.grad(loss)(params, x, y)
# 명시적
g = jax.grad(loss, argnums=0)(params, x, y)
# x 에 대한 gradient
g_x = jax.grad(loss, argnums=1)(params, x, y)
# 여러 인자 동시
g_p, g_x = jax.grad(loss, argnums=(0, 1))(params, x, y)
중요한 점 — params 가 array 일 수도, dict 일 수도, 임의 pytree 일 수도 있어. grad 는 같은 모양의 pytree gradient 를 돌려줘:
params = {
"W1": jnp.zeros((10, 20)),
"b1": jnp.zeros(20),
"W2": jnp.zeros((20, 5)),
}
def loss(params, x, y):
h = jnp.tanh(x @ params["W1"] + params["b1"])
pred = h @ params["W2"]
return jnp.mean((pred - y) ** 2)
grads = jax.grad(loss)(params, x, y)
# grads 는 같은 dict 구조: {"W1": ..., "b1": ..., "W2": ...}
🌿 functional 미분의 우아함
PyTorch: tensor.requires_grad_(), loss.backward(), 그러면 .grad 가 tensor 에 magic 처럼 부착. JAX: g = jax.grad(loss)(params, x, y) — 입력 in, gradient out. 함수형. 어디 magic state 도 없고, 청소할 .zero_grad() 도 없고, 모든 게 보임.
한 가지 더 — gradient 가 정확히 무얼 의미하는지: 함수의 출력을 1 만큼 늘리려면 입력을 어느 방향으로 얼마나 움직여야 하는가. SGD 는 그 반대 방향으로 움직임. 직관 잡고 있으면 — 수학이 안 헷갈려.