built-in op 가 부족할 때
model 코드의 99% 는 built-in differentiable op 면 충분 — autograd 가 chain rule 로 derivative 조합, backward 직접 안 씀. 나머지 1% 가 torch.autograd.Function 에 손 가는 곳:
- forward 가 기존 differentiable op 로 안 만들어지는 op 구현 (예: custom CUDA kernel 짰을 때).
- gradient 재정의 원함 (straight-through estimator — quantization-aware training 과 discrete-output model 의 교과서 예).
- store 대신 recompute 로 메모리 절약 — 그건 보통
torch.utils.checkpoint가 더 나아.
계약
custom Function 에 두 static method:
forward(ctx, *inputs)— forward 계산.ctx.save_for_backward(...)로 backward 가 필요한 거 stash.backward(ctx, *grad_outputs)— backward 계산. forward 의 input 당 gradient 하나 반환 (또는 grad 안 필요한 input 엔 None).
forward 나 backward 직접 호출 안 함 — MyFunction.apply(...) 호출, graph node 를 적절히 setup.