Tape 가 아니라 composition 으로 하는 autograd
PyTorch 의 autograd 는 네가 op 들 하면서 그 tape 을 기록하고, tape 을 거꾸로 걷는 식으로 동작. MLX (와 JAX) 는 다른 길 — 함수를 받아서 새 함수를 돌려주는 변환 들을 노출. mx.grad(f) 는 사후에 f 의 tape 을 미분 안 해. 호출되면 gradient 를 계산하는 새 함수 를 돌려줘.
작은 구분처럼 들려. 안 그래. 의미는 네 gradient 가 first-class composable 값 — 다른 transform 으로 감쌀 수 있고, JIT-compile 가능, batch 에 vmap 가능. 그 어떤 것도 특별한 "training 모드" 또는 살아 있는 tape 안 필요.
가장 많이 쓸 셋
mx.grad(f) — scalar-값 함수가 주어지면, 첫 인자에 대한 gradient 계산하는 함수 돌려줘. Gradient 원하지만 loss 값 신경 안 쓸 때.
mx.value_and_grad(f) — mx.grad 와 같지만, 한 호출에 loss 값과 gradient 둘 다 돌려줘. Training loop 에서 이거 써. 거의 항상 loss 로깅하고 싶으니까.
mx.vmap(f) — 단일-예제 함수를 batch 함수로 바꿔. grad 와 자연스럽게 compose. 수동 broadcasting 없이 효율적인 batch 계산 원할 때.
Composability 가 포인트
각 transform 이 함수를 받고 함수를 돌려주니까, 쌓을 수 있어. mx.vmap(mx.grad(loss_fn)) 가 batch 에 걸친 per-example gradient 줘. mx.grad(mx.grad(f)) 가 second derivative 줘. mx.compile(mx.value_and_grad(loss_fn)) 가 gradient 계산을 JIT-compile. 다른 framework 들이 이 트릭들을 가능하게 해. MLX 는 명백하게 만들어.
nn.value_and_grad — 모델-인식 변종
Parameter 가진 nn.Module 가 있으면, 보통 함수의 첫 인자가 아니라 모델의 parameter 에 대한 gradient 원해. nn 모듈이 학습용 canonical 패턴인 nn.value_and_grad(model, loss_fn) 노출. Lesson 6 에서 쓸 거야.