C.W.K.
Stream
Lesson 05 of 07 · published

mx.grad — 그리고 function transform 가족

~16 min · autograd, mx.grad, vmap, value-and-grad

Level 0Curious
0 XP0/51 lessons0/15 achievements
0/100 XP to next level100 XP to go0% complete

Tape 가 아니라 composition 으로 하는 autograd

PyTorch 의 autograd 는 네가 op 들 하면서 그 tape 을 기록하고, tape 을 거꾸로 걷는 식으로 동작. MLX (와 JAX) 는 다른 길 — 함수를 받아서 새 함수를 돌려주는 변환 들을 노출. mx.grad(f) 는 사후에 f 의 tape 을 미분 안 해. 호출되면 gradient 를 계산하는 새 함수 를 돌려줘.

작은 구분처럼 들려. 안 그래. 의미는 네 gradient 가 first-class composable 값 — 다른 transform 으로 감쌀 수 있고, JIT-compile 가능, batch 에 vmap 가능. 그 어떤 것도 특별한 "training 모드" 또는 살아 있는 tape 안 필요.

가장 많이 쓸 셋

mx.grad(f) — scalar-값 함수가 주어지면, 첫 인자에 대한 gradient 계산하는 함수 돌려줘. Gradient 원하지만 loss 값 신경 안 쓸 때.

mx.value_and_grad(f)mx.grad 와 같지만, 한 호출에 loss 값과 gradient 둘 다 돌려줘. Training loop 에서 이거 써. 거의 항상 loss 로깅하고 싶으니까.

mx.vmap(f) — 단일-예제 함수를 batch 함수로 바꿔. grad 와 자연스럽게 compose. 수동 broadcasting 없이 효율적인 batch 계산 원할 때.

Composability 가 포인트

각 transform 이 함수를 받고 함수를 돌려주니까, 쌓을 수 있어. mx.vmap(mx.grad(loss_fn)) 가 batch 에 걸친 per-example gradient 줘. mx.grad(mx.grad(f)) 가 second derivative 줘. mx.compile(mx.value_and_grad(loss_fn)) 가 gradient 계산을 JIT-compile. 다른 framework 들이 이 트릭들을 가능하게 해. MLX 는 명백하게 만들어.

nn.value_and_grad — 모델-인식 변종

Parameter 가진 nn.Module 가 있으면, 보통 함수의 첫 인자가 아니라 모델의 parameter 에 대한 gradient 원해. nn 모듈이 학습용 canonical 패턴인 nn.value_and_grad(model, loss_fn) 노출. Lesson 6 에서 쓸 거야.

Code

mx.grad — 한 줄짜리·python
import mlx.core as mx

def square(x):
    return x ** 2

grad_sq = mx.grad(square)
print(grad_sq(mx.array(3.0)))   # → array(6, dtype=float32)   d/dx(x^2) = 2x = 6
value_and_grad — training-loop 원시형·python
import mlx.core as mx

def loss(w, x, y):
    pred = w * x
    return ((pred - y) ** 2).mean()

vg = mx.value_and_grad(loss)

w0 = mx.array(1.5)
xs = mx.array([1.0, 2.0, 3.0, 4.0])
ys = mx.array([2.0, 4.0, 6.0, 8.0])      # true relation y = 2x

v, g = vg(w0, xs, ys)
print('loss value:', float(v))            # → 1.875
print('gradient w.r.t. w:', float(g))     # → -7.5  (you'd subtract this in SGD)
vmap — batch 함수 호출·python
import mlx.core as mx

def f(x):
    return x ** 2 + 1

# Vectorize f over the leading axis
batched = mx.vmap(f)

xb = mx.array([1.0, 2.0, 3.0, 4.0])
print(batched(xb))   # → array([2, 5, 10, 17], dtype=float32)

# Composes with grad: per-example gradient across a batch
batched_grad = mx.vmap(mx.grad(f))
print(batched_grad(xb))   # → array([2, 4, 6, 8], dtype=float32)   (= 2x for each)

External links

Exercise

두 인자의 작은 scalar 함수 골라 (예 def f(x, y): return (x - y) ** 2 + x). mx.grad(f, argnums=0)mx.grad(f, argnums=1) 써서 xy 에 대한 partial derivative 계산. 손으로 답 검증. 그 다음 (x, y) 쌍 batch 에 mx.vmap compose. 운동의 포인트는 argnums + composition 이 custom backward pass 박지 않고 완전한 통제 준다고 느끼는 것.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.