C.W.K.
Stream
Lesson 04 of 10 · published

no_grad, inference_mode, detach

~10 min · no_grad, inference, detach

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

'gradient 추적 안 함' 을 말하는 세 방법

대부분 시간 autograd recording 원함. 가끔 안 원함 — inference, evaluation, preprocessing, metric 계산, plot. Recording 은 메모리 비용 (activation 들고 있어야) 과 작은 CPU. 건너뛰는 세 도구:

  • torch.no_grad() — context manager. 블록 안 어떤 op 도 graph 안 만듦. model.eval() 코드 path 의 표준 wrapper.
  • torch.inference_mode() — no_grad 의 더 강한 버전. autograd AND view-tracking 둘 다 비활성, 살짝 더 빠름. 순수 inference 의 추천 선택.
  • tensor.detach() — storage 공유하지만 autograd graph 에서 disconnect 된 새 tensor 반환. 코드 블록 전체 아니라 individual tensor 용 (loss 값 logging, RL target network).

경계: model.eval() vs no_grad()

둘은 관련 있지만 다름:

  • model.eval()module 행동 변경 — Dropout 이 no-op, BatchNorm 이 batch stats 대신 running stats 사용. autograd 비활성화 안 함.
  • torch.no_grad() 가 autograd 비활성. module 행동 안 바꿈.

evaluation 위해 거의 항상 둘 다 원함: 올바른 행동 위해 model.eval(), 속도 위해 torch.no_grad() (또는 inference_mode()).

Code

no_grad 와 inference_mode — context manager·python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
x = torch.randn(4, 10)

# no_grad: context manager
with torch.no_grad():
    y = model(x)
    print(y.requires_grad)        # False — no graph built

# inference_mode: stronger and faster
with torch.inference_mode():
    y = model(x)
    print(y.requires_grad)        # False

# As a decorator
@torch.inference_mode()
def predict(model, x):
    return model(x)
detach — single tensor disconnect·python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

# Logging the value — don't drag autograd along
loss_value = y.detach().item()    # plain Python float

# A target network in RL — gradients should NOT flow into it
target = (x ** 2).detach()         # treated as a constant from now on

# detach IS view-shaped — same storage, different graph status
print(y.detach().data_ptr() == y.data_ptr())  # True
완전한 evaluation idiom·python
import torch
import torch.nn as nn

def evaluate(model, val_loader, criterion, device):
    model.eval()                       # behavior switch (dropout off, etc.)
    total_loss = total_correct = total_n = 0

    with torch.inference_mode():       # autograd switch (off)
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)

            total_loss += loss.item() * x.size(0)
            total_correct += (out.argmax(-1) == y).sum().item()
            total_n += x.size(0)

    return total_loss / total_n, total_correct / total_n

External links

Exercise

작은 model 잡기, inference 100 번 (a) autograd context 없이, (b) torch.no_grad, (c) torch.inference_mode 에서 시간. 작은 model 에서도 순서: no context > no_grad > inference_mode. 숫자 보관 — 대략 비율 알면 eval 시간 budget 잡을 때 유용.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.