Functional autograd, batched transform
torch.func (옛 standalone functorch 의 in-tree 대체) 가 JAX-style function transform 줘: grad, vmap, jacrev, hessian. single example 에 동작하는 코드를 짜면 batch 에 자동 vectorize.
의외로 자주 등장하는 두 실용 use case:
- Per-sample gradient. 표준
.backward()는 batch 의 합쳐진 loss 의 gradient 줘 — 모든 parameter 에 gradient 하나. 일부 research (differential privacy, influence function, GradSAM) 는 모든 individual sample 의 loss 의 gradient 필요.torch.func.vmap(grad(...))가 Python loop 없이 그걸 줘. - Higher-order gradient. Hessian-vector product, second-order optimizer, meta-learning — 모두 gradient 의 gradient 필요.
torch.func.grad(grad(f))가 깔끔히 조합.
정신적 전환
표준 PyTorch: tensor 가 implicit graph state 들고, .backward() 호출. torch.func: input 과 parameter 의 pure function, transform 이 새 pure function 생산. 작지만 실재하는 mindset 변화 — JAX 에 더 가까움, 정상 training 에 살짝 덜 ergonomic, 위 case 엔 훨씬 더 강력.