C.W.K.
Stream
Lesson 04 of 08 · published

한 Train Step

~18 min · train-step, zero-grad, step

Level 0Curious
0 XP0/73 lessons0/11 achievements
0/120 XP to next level120 XP to go0% complete

5 줄 의식

Training 의 unit 은 한 step. 항상 같은 5 줄, 같은 순서:

  1. opt.zero_grad() — 이전 step 의 gradient 클리어.
  2. logits = model(xb) — forward pass.
  3. loss = loss_fn(logits, yb) — scalar loss.
  4. loss.backward() — autograd 로 gradient 계산.
  5. opt.step() — gradient 를 parameter 에 적용.

Optionally backward 와 step 사이에 clip_grad_norm_ 끼우고, step 다음에 scheduler.step(). 그게 전체 loop body. 다른 모든 건 그 주변 plumbing.

팁: 5 줄 의식을 외워서 화이트보드에 못 그리면 아직 muscle 없는 거. 20 번 써. 결국 문장 끝나기 전에 opt.zero_grad() 치게 돼.

Step 1 의 유명한 두 버그

zero_grad() 잊음 — gradient 가 step 사이 누적, 각 step 이 모든 이전 gradient 합 봄. Loss 높이 머물고, training collapse. Fix 는 한 줄.

loss.backward() 전에 opt.step() — stale gradient (이전 batch 의) 에 step. Loss curve 가 막연히 training 처럼 보이는데 model 이 실제로 현재 batch 학습 안 함.

Gradient accumulation (zero_grad 건너뛰는 올바른 방법)

때때로 VRAM 에 들어가는 것보다 큰 effective batch size 원함. Trick: 여러 batch 동안 zero_grad() 건너뛰고, loss 를 accumulation step 수로 나누고, 끝에 opt.step() + opt.zero_grad(). accum_steps 배 큰 single batch 와 equivalent gradient 생성.

원칙: 5 줄 의식이 training 의 가장 작은 unit. 외워. 그 다음 한 번에 한 변형 (gradient accumulation, mixed precision, gradient clipping) 학습. Basic version 건너뛰지 마.

Code

One train step, with all the optional pieces·python
from torch.nn.utils import clip_grad_norm_

opt.zero_grad()
logits = model(xb)
loss = loss_fn(logits, yb)
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)   # optional
opt.step()
scheduler.step()                                    # optional
Gradient accumulation for big effective batch·python
ACCUM = 4
opt.zero_grad()
for micro_idx, (xb, yb) in enumerate(loader):
    logits = model(xb)
    loss = loss_fn(logits, yb) / ACCUM        # rescale for averaging
    loss.backward()                            # accumulates into .grad
    if (micro_idx + 1) % ACCUM == 0:
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        opt.zero_grad()

External links

Exercise

Training loop 의 opt.zero_grad() comment out 하고 loss curve 봐. 이제 ACCUM=4 로 gradient accumulation 시도 — 결과가 single 4x batch run 과 일치하는지 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.