C.W.K.
Stream
Lesson 06 of 10 · published

Custom Autograd Function

~12 min · custom, Function, advanced

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

built-in op 가 부족할 때

model 코드의 99% 는 built-in differentiable op 면 충분 — autograd 가 chain rule 로 derivative 조합, backward 직접 안 씀. 나머지 1% 가 torch.autograd.Function 에 손 가는 곳:

  • forward 가 기존 differentiable op 로 안 만들어지는 op 구현 (예: custom CUDA kernel 짰을 때).
  • gradient 재정의 원함 (straight-through estimator — quantization-aware training 과 discrete-output model 의 교과서 예).
  • store 대신 recompute 로 메모리 절약 — 그건 보통 torch.utils.checkpoint 가 더 나아.

계약

custom Function 에 두 static method:

  • forward(ctx, *inputs) — forward 계산. ctx.save_for_backward(...) 로 backward 가 필요한 거 stash.
  • backward(ctx, *grad_outputs) — backward 계산. forward 의 input 당 gradient 하나 반환 (또는 grad 안 필요한 input 엔 None).

forwardbackward 직접 호출 안 함 — MyFunction.apply(...) 호출, graph node 를 적절히 setup.

Code

custom ReLU — 이해용, production 아님·python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)             # stash x for backward
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_in = grad_out.clone()
        grad_in[x < 0] = 0                    # zero gradient where x < 0
        return grad_in

my_relu = MyReLU.apply

x = torch.randn(5, requires_grad=True)
y = my_relu(x)
y.sum().backward()
print(x.grad)   # zeros where x<0, ones where x>=0
Straight-Through Estimator — gradient 그대로 통과·python
import torch

class STE(torch.autograd.Function):
    """Forward: hard threshold. Backward: pretend it was identity."""
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out                      # straight through

binarize = STE.apply

# Useful for quantization-aware training:
# the forward step is non-differentiable, but we still need a learning signal
x = torch.randn(4, requires_grad=True)
y = binarize(x)
y.sum().backward()
print(x.grad)   # ones — gradient passed through unchanged
gradcheck 로 derivative sanity-check·python
import torch
from torch.autograd import gradcheck

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_in = grad_out.clone()
        grad_in[x < 0] = 0
        return grad_in

# Use float64 for numerical stability when gradchecking
x = torch.randn(8, dtype=torch.float64, requires_grad=True)
print(gradcheck(MyReLU.apply, (x,), eps=1e-6, atol=1e-5))
# True if your analytic backward matches finite-difference numerics

External links

Exercise

custom Sigmoid Function 을 forward 와 backward 둘 다로 구현. float64 tensor 에 torch.autograd.gradcheck 로 검증. 그 다음 같은 input 에서 forward 가 torch.sigmoid 와 일치하는지 검증. 의심될 때마다 다시 돌릴 수 있는 Python 파일로 test 저장.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.