C.W.K.
Stream
Lesson 07 of 08 · published

Checkpointing — 재개에 필요한 모든 거 저장

~12 min · checkpoint, save, load, resume

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

full model 보다 state_dict — 매번

model 저장 두 방법:

  • state_dict (추천) — Python OrderedDict 로 parameter + buffer 만. 로드 위해 model 클래스 instantiate 하고 load_state_dict. code 변경, framework version 차이, 클래스 rename 에 robust.
  • torch.save(model) — 전체 Python 객체 pickle. fragile: 정확한 클래스 정의 필요, rename / refactor 깨뜨림.

항상 state_dict 사용.

'training' checkpoint 에 들어가는 것

inference 만이면 model state_dict 만. training 재개 위해선 필요:

  • Model state_dict
  • Optimizer state_dict (Adam 의 running moment 가 trivial 하게 재생성 안 됨)
  • Scheduler state_dict (현재 step / last_lr / 등)
  • 현재 epoch / step 번호
  • 지금까지 best validation metric (early stopping 위)
  • 정확 reproducibility 신경 쓰면 RNG state

weights_only=True

로드 시 weights_only=True 넘기기. PyTorch 2.x 가 결국 default 로 할 거지만, 명시적이 좋음. 임의 Python deserialize 거부 — tensor data 만. 악성 .pt 파일이 로드 시 코드 실행하는 (실재) 공격 클래스 방어.

Early stopping

val loss 개선 시 'best so far' checkpoint 저장; patience 추적, N epoch 개선 없이 지나면 중단. 다섯 줄 클래스 state, 낭비된 compute 방지.

Code

Save 와 load — 단순 inference case·python
import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self): super().__init__(); self.fc = nn.Linear(10, 4)
    def forward(self, x): return self.fc(x)

# Train
model = TinyMLP()
torch.save(model.state_dict(), 'tiny_mlp.pt')

# Load — instantiate first, then restore weights
model2 = TinyMLP()
model2.load_state_dict(torch.load('tiny_mlp.pt', weights_only=True))
model2.eval()
Full training-resume checkpoint·python
import torch

def save_ckpt(path, model, optimizer, scheduler, epoch, best_val):
    torch.save({
        'epoch': epoch,
        'best_val': best_val,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'rng_state': torch.get_rng_state(),
    }, path)

def load_ckpt(path, model, optimizer, scheduler):
    ckpt = torch.load(path, weights_only=False)   # need False for non-tensor objects
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    torch.set_rng_state(ckpt['rng_state'])
    return ckpt['epoch'], ckpt['best_val']
Early stopping — 다섯 줄 state·python
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best is None or val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

# Usage
early = EarlyStopping(patience=5)
for epoch in range(100):
    train_one_epoch(...)
    val_loss, _ = evaluate(...)
    early(val_loss)
    if early.should_stop:
        print(f"Early stopping at epoch {epoch}")
        break

External links

Exercise

training loop 에 EarlyStopping pattern 추가. 50 epoch 까지 train 하지만 patience=5. training 종료 시 'best epoch was N (val_loss=X)' print. best 와 last checkpoint 둘 다 저장. best checkpoint 로드 가능하고 held-out batch 에 inference 돌릴 수 있는지 검증.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.