Checkpoint 는 다시 load 할 수 있는 dictionary. 포함해야 할 것: model state dict, optimizer state dict, scheduler state dict, GradScaler state dict (mixed precision 쓰면), step number, epoch number, config (이 checkpoint 가 뭘 train 했는지 알게).
Model state dict 만 저장은 inference 엔 OK 인데 training resume 엔 useless. Optimizer 와 scheduler 가 model 만으로 reconstruct 못 할 running average 와 momentum 들고 있어.
팁: 항상 optimizer 도 저장. 누가 'model 만' 에서 resume 했는데 carefully tuned learning rate 가 0 부터 다시 시작하는 거 본 횟수 셀 수 없어.
Best-checkpoint vs latest-checkpoint
Best — 지금까지 가장 높은 validation metric 의 checkpoint. Ship 할 거.
Latest — metric 무관 가장 최근 checkpoint. Crash 후 resume 할 거.
둘 다 있어야 해. Best 는 production, latest 는 engineering. 다른 filename 에 저장해서 한 쪽이 다른 쪽 overwrite 안 하게.
저장 위치
짧은 training 엔 local disk OK. Multi-day run 에는 S3 / GCS / fileserver 에 저장 — node 죽음에 살아남게. 일부 training framework (PyTorch Lightning, Composer, HF Trainer) 가 처리해 줘.
원칙: Checkpointing 은 보험. 80% 에서 뭔가 crash 나서 GPU 시간 한 주 낭비 안 하게 하는 가장 싼 일.
Code
Save and resume a full training state·python
import torch
def save_ckpt(path, model, opt, sch, scaler, step, epoch, cfg, metric):
torch.save({
"model": model.state_dict(),
"optimizer": opt.state_dict(),
"scheduler": sch.state_dict() if sch else None,
"scaler": scaler.state_dict() if scaler else None,
"step": step,
"epoch": epoch,
"config": cfg,
"best_metric": metric,
}, path)
def load_ckpt(path, model, opt=None, sch=None, scaler=None):
ckpt = torch.load(path, map_location="cpu")
model.load_state_dict(ckpt["model"])
if opt and ckpt.get("optimizer"):
opt.load_state_dict(ckpt["optimizer"])
if sch and ckpt.get("scheduler"):
sch.load_state_dict(ckpt["scheduler"])
if scaler and ckpt.get("scaler"):
scaler.load_state_dict(ckpt["scaler"])
return ckpt["step"], ckpt["epoch"], ckpt.get("best_metric", 0.0)
5 epoch train, checkpoint 저장, process 죽이기, resume 해서 5 epoch 더 train. Validation curve 가 resume 통과 continuous 해야 — 어느 방향이든 jump 없음. Jump 있으면 본인 save/load 에 뭔가 빠짐.
Progress
Progress is local-only — sign in to sync across devices.