C.W.K.
Stream
Lesson 07 of 08 · published

Checkpoint

~18 min · checkpoint, save, resume

Level 0Curious
0 XP0/73 lessons0/11 achievements
0/120 XP to next level120 XP to go0% complete

Checkpoint 가 실제로 담는 것

Checkpoint 는 다시 load 할 수 있는 dictionary. 포함해야 할 것: model state dict, optimizer state dict, scheduler state dict, GradScaler state dict (mixed precision 쓰면), step number, epoch number, config (이 checkpoint 가 뭘 train 했는지 알게).

Model state dict 만 저장은 inference 엔 OK 인데 training resume 엔 useless. Optimizer 와 scheduler 가 model 만으로 reconstruct 못 할 running average 와 momentum 들고 있어.

팁: 항상 optimizer 도 저장. 누가 'model 만' 에서 resume 했는데 carefully tuned learning rate 가 0 부터 다시 시작하는 거 본 횟수 셀 수 없어.

Best-checkpoint vs latest-checkpoint

Best — 지금까지 가장 높은 validation metric 의 checkpoint. Ship 할 거.

Latest — metric 무관 가장 최근 checkpoint. Crash 후 resume 할 거.

둘 다 있어야 해. Best 는 production, latest 는 engineering. 다른 filename 에 저장해서 한 쪽이 다른 쪽 overwrite 안 하게.

저장 위치

짧은 training 엔 local disk OK. Multi-day run 에는 S3 / GCS / fileserver 에 저장 — node 죽음에 살아남게. 일부 training framework (PyTorch Lightning, Composer, HF Trainer) 가 처리해 줘.

원칙: Checkpointing 은 보험. 80% 에서 뭔가 crash 나서 GPU 시간 한 주 낭비 안 하게 하는 가장 싼 일.

Code

Save and resume a full training state·python
import torch

def save_ckpt(path, model, opt, sch, scaler, step, epoch, cfg, metric):
    torch.save({
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scheduler": sch.state_dict() if sch else None,
        "scaler": scaler.state_dict() if scaler else None,
        "step": step,
        "epoch": epoch,
        "config": cfg,
        "best_metric": metric,
    }, path)

def load_ckpt(path, model, opt=None, sch=None, scaler=None):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    if opt and ckpt.get("optimizer"):
        opt.load_state_dict(ckpt["optimizer"])
    if sch and ckpt.get("scheduler"):
        sch.load_state_dict(ckpt["scheduler"])
    if scaler and ckpt.get("scaler"):
        scaler.load_state_dict(ckpt["scaler"])
    return ckpt["step"], ckpt["epoch"], ckpt.get("best_metric", 0.0)

External links

Exercise

5 epoch train, checkpoint 저장, process 죽이기, resume 해서 5 epoch 더 train. Validation curve 가 resume 통과 continuous 해야 — 어느 방향이든 jump 없음. Jump 있으면 본인 save/load 에 뭔가 빠짐.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.