C.W.K.
Stream
Lesson 01 of 05 · published

grad 의 일: 함수 in, gradient 함수 out

~8 min · grad, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.grad 는 higher-order function — 함수를 받아서 새 함수를 돌려줘. 새 함수는 같은 인자를 받지만, 결과는 원래 함수의 gradient.

import jax
import jax.numpy as jnp

def f(x):
    return x ** 2 + 3 * x + 5

# grad(f) 는 새로운 함수
df = jax.grad(f)

print(f(2.0))    # 15.0
print(df(2.0))   # 7.0  (정답: 2x + 3 = 7)
print(df(0.0))   # 3.0
print(df(-1.0))  # 1.0

한 가지 규칙 — grad 가 받는 함수는 scalar 를 반환해야 해. 여러 출력을 가진 함수의 gradient 가 뭔지 — 수학적으로는 Jacobian — JAX 는 별도 함수 jax.jacrev, jax.jacfwd 로 처리.

def g(x):
    return jnp.array([x ** 2, x ** 3])

# jax.grad(g)(2.0)  # ❌ TypeError: grad requires scalar output
jacobian = jax.jacrev(g)
print(jacobian(2.0))  # [4., 12.]  ← d/dx [x², x³]

여러 인자에 대한 gradient — default 로 첫 번째 인자에만:

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# 첫 번째 인자 (params) 만
g = jax.grad(loss)(params, x, y)

# 명시적
g = jax.grad(loss, argnums=0)(params, x, y)

# x 에 대한 gradient
g_x = jax.grad(loss, argnums=1)(params, x, y)

# 여러 인자 동시
g_p, g_x = jax.grad(loss, argnums=(0, 1))(params, x, y)

중요한 점 — params 가 array 일 수도, dict 일 수도, 임의 pytree 일 수도 있어. grad 는 같은 모양의 pytree gradient 를 돌려줘:

params = {
    "W1": jnp.zeros((10, 20)),
    "b1": jnp.zeros(20),
    "W2": jnp.zeros((20, 5)),
}

def loss(params, x, y):
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    pred = h @ params["W2"]
    return jnp.mean((pred - y) ** 2)

grads = jax.grad(loss)(params, x, y)
# grads 는 같은 dict 구조: {"W1": ..., "b1": ..., "W2": ...}

🌿 functional 미분의 우아함

PyTorch: tensor.requires_grad_(), loss.backward(), 그러면 .grad 가 tensor 에 magic 처럼 부착. JAX: g = jax.grad(loss)(params, x, y) — 입력 in, gradient out. 함수형. 어디 magic state 도 없고, 청소할 .zero_grad() 도 없고, 모든 게 보임.

한 가지 더 — gradient 가 정확히 무얼 의미하는지: 함수의 출력을 1 만큼 늘리려면 입력을 어느 방향으로 얼마나 움직여야 하는가. SGD 는 그 반대 방향으로 움직임. 직관 잡고 있으면 — 수학이 안 헷갈려.

Code

import jax
import jax.numpy as jnp

# A scalar-valued function
def f(x):
    return x ** 2

# grad(f) returns a NEW function that computes df/dx
df = jax.grad(f)

print(f(3.0))    # 9.0
print(df(3.0))   # 6.0 (derivative of x^2 is 2x, evaluated at x=3)

# df is a regular Python function — you can call it, JIT it, etc.
print(df(5.0))   # 10.0
import jax
import jax.numpy as jnp

# Multi-variable function
def loss(params, x, y):
    w, b = params
    pred = jnp.dot(x, w) + b
    return jnp.mean((pred - y) ** 2)

# grad differentiates with respect to the FIRST argument by default
grad_fn = jax.grad(loss)

# Call it: returns gradient with same structure as params
w = jnp.array([1.0, 2.0])
b = jnp.array(0.0)
params = (w, b)
x = jnp.array([[1.0, 0.5], [0.3, 0.8]])
y = jnp.array([1.0, 0.5])

grads = grad_fn(params, x, y)
print(type(grads))       # tuple — same structure as params!
print(grads[0].shape)    # (2,) — gradient of w
print(grads[1].shape)    # () — gradient of b

External links

Exercise

f(x, y) = sin(x) * cos(y) 의 gradient 손으로 도출. jax.grad(f, argnums=(0, 1))(x, y) 호출 후 수치적 비교. 일치 = argnums 이해의 sanity check. 그 다음 일부러 깨기: scalar 아닌 output 넣고 error 관찰.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.