C.W.K.
Stream
Lesson 10 of 10 · published

잘못 행동하는 Gradient 디버깅

~14 min · debug, anomaly, nan, inf, diagnostics

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

backward 가 잘못 가면 도구 있어

발산하는 training, NaN loss, 신비롭게 0 또는 무한대가 되는 gradient — 이게 고전적 'model 이 뭐가 잘못' 증상. 네 PyTorch-specific 도구가 진단 표면 대부분 cover:

  1. Gradient norm logging. backward 후 clip_grad_norm_(model.parameters(), float('inf')) 추가하고 반환값 log. 1e6 spike 가 '방금 무슨 일' 신호.
  2. NaN / Inf 감지. 매 step 에 torch.isfinite(loss) 가 싼 보험 — Inf 도 NaN 도 둘 다 잡아. torch.autograd.set_detect_anomaly(True) 가 NaN backward 만든 정확한 op tag — 느리지만 막혔을 때 priceless. 주의 — scope 어긋남: anomaly detection 은 NaN 에만 발사, Inf 에는 안 해 (아래 'Inf vs NaN' 섹션).
  3. Per-parameter gradient stats. backward 후 model.named_parameters() 걸으며 각 mean/std/max print. Vanishing-gradient 버그가 이 view 에서 명백.
  4. shape mismatch 용 gradient hook. layer output 에 hook 등록해서 gradient shape 가 기대와 일치 검증 — skip connection 또는 custom layer 추가할 때 유용.

Inf vs NaN — anomaly detection 의 scope 는 좁아

대부분 튜토리얼이 건너뛰는 갭: set_detect_anomaly(True) 가 raise 하는 건 backward output 의 NaN 뿐, Inf 아니야. 같은 함수도 input 따라 둘 중 한 bucket 으로 떨어져.

Inf 출처 — anomaly detection 침묵, torch.isfinite() 로 잡아:

  • sqrt(0).backward() — derivative 가 1/(2·sqrt(x)) = ∞ at x=0. Fix: (x + 1e-9).sqrt().
  • log(0).backward() — derivative 가 1/x = ∞ at x=0. Fix: (p + 1e-9).log().

NaN 출처 — anomaly detection 발사, 정확한 op 지목:

  • sqrt(-1), log(-1) — forward 에서 NaN, backward 로 propagate.
  • 0 / 0 — 보통 norm 으로 normalize 에서 norm 이 0.
  • 0 * log(0) — 0 × (-∞) = NaN.
  • log(softmax(x)) 에서 softmax 가 한 entry 에 ~0 으로 collapse — NaN backward 만듦 (수치 안정성 위해 F.log_softmax 써 — exercise 참고).
  • fp16 overflow 가 0×∞ 로 가는 경우 — bf16 로 전환 또는 GradScaler 추가.

Code

set_detect_anomaly — Inf trap (raise 안 함)·python
import torch

# Turn on at the start of training when debugging — turn OFF for real runs
torch.autograd.set_detect_anomaly(True)

# Inf trap: anomaly detection WON'T catch this — sqrt(0) backward = inf, not nan
x = torch.tensor(0.0, requires_grad=True)
y = torch.sqrt(x)
y.backward()
print(x.grad, torch.isfinite(x.grad))   # tensor(inf), tensor(False) — no RuntimeError
# Use isfinite() to catch Inf; anomaly detection only fires on NaN.
set_detect_anomaly — NaN backward 는 raise 함·python
import torch

torch.autograd.set_detect_anomaly(True)

# Same sqrt function, different domain — sqrt(-1) IS caught
x = torch.tensor(-1.0, requires_grad=True)
y = torch.sqrt(x)            # forward: nan, propagates to backward
y.backward()
# RuntimeError: Function 'SqrtBackward0' returned nan values in its 0th output.
# The traceback points at the exact forward call that created the nan-producing op.
Per-parameter gradient stats — 진단 walk·python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
loss = nn.functional.cross_entropy(model(torch.randn(8, 100)), torch.randint(0, 10, (8,)))
loss.backward()

print(f"{'name':25s}{'mean':>12s}{'std':>12s}{'max_abs':>12s}{'has_nan':>10s}")
for name, p in model.named_parameters():
    if p.grad is None:
        print(f"{name:25s}  NO GRADIENT")
        continue
    g = p.grad
    print(f"{name:25s}{g.mean():12.2e}{g.std():12.2e}{g.abs().max():12.2e}"
          f"{str(torch.isnan(g).any().item()):>10s}")
싼 NaN 보험 — 모든 training loop·python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step, (x, y) in enumerate(loader):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)

    # Cheap. Worth it. isfinite() catches both Inf and NaN.
    if not torch.isfinite(loss):
        print(f"Step {step}: non-finite loss = {loss.item()}; halting.")
        # Save a snapshot of the offending batch for later analysis
        torch.save({'x': x, 'y': y, 'state': model.state_dict()},
                   f'nan_snapshot_step_{step}.pt')
        break

    loss.backward()
    optimizer.step()

External links

Exercise

의도적 NaN 만들기: logit 을 softmax 통과시키고 result 의 log 잡기. autograd.set_detect_anomaly(True) 로 backward 실패하게 하고 가리킨 정확한 op 읽기. 이제 fix (log(softmax(x)) 대신 F.log_softmax 사용) 하고 backward clean 인지 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.