모든 PyTorch training loop 는 같은 skeleton: 각 epoch 에 training loader 반복, loss 계산, backprop, optimizer step. 주기적으로 (보통 epoch 마다 또는 N step 마다) model.eval() 로 validation set 에서 evaluate. 가장 좋은 checkpoint 저장. Validation metric 개선 멈출 때까지 반복.
이 skeleton 이 universal 해서 PyTorch user 들은 한 번 손으로 쓰고 재사용하거나, higher-level framework (PyTorch Lightning, Hugging Face Trainer, Accelerate) 가 대신 쓰게 함. 둘 다 합리적. 본인 거 쓰는 게 better learning, framework 쓰는 게 better production 선택.
팁: Framework 채택 전에 적어도 한 번 손으로 training loop 써. Framework 가 typing 절약하지만, 뭔가 잘못됐을 때 디버깅에 필요한 step 도 숨겨.
Junior training loop 에서 잘못되는 것
opt.zero_grad() 잊음 — gradient 가 이전 step 에서 누적.
Validation 에서 model.eval() 잊음 — dropout 과 batchnorm misbehave, metric noisy.
GPU 에서 loss 계산하고 .item() 없이 CPU 에서 print/log — invisible CUDA sync 곳곳.
Probability 의 argmax 와 logit 의 argmax inconsistent 하게 accuracy 계산 — 운으로 같은 결과, softmax 변하면 깨짐.
Best validation step 이 아니라 training 끝에 model state dict 저장 — 잘못된 model ship.
Non-negotiable 추가
Skeleton 외에 모든 진지한 training loop 에 있어: gradient clipping, learning-rate scheduling, GPU 에 mixed-precision (autocast + GradScaler), 주기적 checkpoint 저장, held-out set 에 대한 validation, 본인 monitoring tool 에서 읽을 수 있는 logging. 다음 track 이 각각을 mechanical 하게 만들어.
원칙: Training loop 는 progress 잃지 않고 죽이고 resume 할 수 있을 때까지 'done' 아냐. 모든 저장된 checkpoint 가 model, optimizer, scheduler, scaler, step number 포함해야 해.
Code
Minimal but honest training loop·python
import torch
from torch import nn, optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel().to(device)
opt = optim.AdamW(model.parameters(), lr=3e-4)
sch = warmup_cosine(opt, warmup_steps=200, total_steps=5_000)
loss_fn = nn.CrossEntropyLoss()
best_val_acc = 0.0
step = 0
for epoch in range(10):
model.train()
for xb, yb in train_loader:
xb, yb = xb.to(device), yb.to(device)
opt.zero_grad()
logits = model(xb)
loss = loss_fn(logits, yb)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step(); sch.step()
step += 1
model.eval()
correct, total = 0, 0
with torch.inference_mode():
for xb, yb in val_loader:
xb, yb = xb.to(device), yb.to(device)
preds = model(xb).argmax(dim=-1)
correct += (preds == yb).sum().item()
total += yb.size(0)
val_acc = correct / total
print(f"epoch {epoch} step {step} val_acc={val_acc:.4f}")
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save({
"model": model.state_dict(),
"optimizer": opt.state_dict(),
"scheduler": sch.state_dict(),
"step": step, "val_acc": val_acc,
}, "best.pt")