C.W.K.
Stream
Lesson 04 of 06 · published

Mixed Precision and Gradient Checkpointing

~12 min · amp, bf16, checkpoint, memory


Two memory and speed tools to always keep at hand

Track 2 covered the basics of AMP. This lesson gathers the modern combination in one place and adds gradient checkpointing, a trick that trades compute for memory.

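AMP review: the modern bf16 path

On Ampere or newer GPUs (A100, RTX 30/40 series, H100) and on Apple Silicon, prefer autocast(device_type='cuda', dtype=torch.bfloat16) with no GradScaler (on Apple Silicon the device type is 'mps' rather than 'cuda'). bf16 runs at roughly the same speed as fp16, but because it keeps fp32's 8-bit exponent range, small gradients do not underflow, so loss scaling is unnecessary. fp16 + GradScaler is still the path on older hardware without bf16 support (V100, T4), but for new training runs on modern GPUs, bf16 wins.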
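To see why bf16 can drop the GradScaler, here is a small CPU-only sketch comparing the two 16-bit formats with torch.finfo. The value 1e-8 is just an illustrative stand-in for a tiny gradient: fp16 has a narrow exponent range, so it rounds to zero, while bf16 shares fp32's exponent range and gives up mantissa precision instead (its eps is larger).

```python
import torch

# Range and precision of each floating-point format.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.3e} tiny={info.tiny:.3e} eps={info.eps:.3e}")

# An illustrative gradient-sized value (hypothetical, for demonstration only).
g = torch.tensor(1e-8)
g_fp16 = g.to(torch.float16)   # below fp16's smallest subnormal: rounds to zero
g_bf16 = g.to(torch.bfloat16)  # well inside bf16's range: stays nonzero

print("fp16:", g_fp16.item())  # the underflow that GradScaler's loss scaling works around
print("bf16:", g_bf16.item())  # approximately 1e-8, coarser mantissa but no underflow
```

The trade-off is visible in the eps column: bf16 is less precise per value than fp16, which is why autocast still keeps precision-sensitive ops such as reductions and losses in fp32.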

Gradient checkpointing

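Normally the forward pass stores every intermediate activation the backward pass will need. In deep networks (say, a 24-layer Transformer at sequence length 4K), those activations dominate memory. Gradient checkpointing skips storing some of them and recomputes them on demand during backward. Memory typically drops by 50-70%, while training time rises by about 30%. When it is the difference between fitting and hitting OOM, it is almost always worth it.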
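A minimal sketch that makes the recomputation visible. CountingBlock is a hypothetical toy module that counts how often its forward runs: with torch.utils.checkpoint.checkpoint, the forward executes once in the forward pass and once more during backward, and the gradients match the non-checkpointed run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Hypothetical toy block that counts how many times its forward runs."""

    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return x + self.ff(x)


def run(use_checkpoint):
    torch.manual_seed(0)  # identical weights and inputs for both runs
    block = CountingBlock(16)
    x = torch.randn(8, 16, requires_grad=True)
    if use_checkpoint:
        # Activations inside `block` are not kept; its forward is re-run during backward.
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return block.calls, x.grad.clone()


calls_plain, grad_plain = run(use_checkpoint=False)
calls_ckpt, grad_ckpt = run(use_checkpoint=True)
print("forward calls without checkpoint:", calls_plain)  # 1
print("forward calls with checkpoint:   ", calls_ckpt)   # 2
print("same gradients:", torch.allclose(grad_plain, grad_ckpt))
```

The extra forward call is where the roughly 30% slowdown comes from: every checkpointed segment pays for its forward twice.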

Putting it together

AMP, gradient checkpointing, torch.compile and FSDP all compose. When scaling up, add them in this order:

  1. Start with eager fp32: get correctness first.
  2. Add bf16 autocast: a roughly free 1.5–2x speedup.
  3. Add torch.compile: another ~1.5–2x.
  4. Still OOM at the batch size you want: add gradient checkpointing.
  5. The model itself does not fit: use FSDP.
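The ordering above can be sketched as independent switches on a single training step. This is an illustration under assumptions: Block, Net and build_train_step are hypothetical names, the loss is a dummy MSE, and FSDP (step 5) is left out because it needs a multi-process launch. With every flag off you get step 1; each flag then turns on one later step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical residual MLP block standing in for a Transformer layer."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)


class Net(nn.Module):
    def __init__(self, dim=32, depth=4, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.head = nn.Linear(dim, 1)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                x = checkpoint(blk, x, use_reentrant=False)  # step 4: recompute in backward
            else:
                x = blk(x)
        return self.head(x)


def build_train_step(device, bf16=False, compile_model=False, ckpt=False):
    # All flags off = step 1: eager fp32, the correctness baseline.
    model = Net(use_checkpoint=ckpt).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    if compile_model:
        model = torch.compile(model)  # step 3 (needs a working compiler toolchain)

    def train_step(x, y):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        # Step 2: bf16 autocast; enabled=False turns this into plain fp32.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=bf16):
            out = model(x)
        loss = F.mse_loss(out.float(), y)  # loss computed in fp32
        loss.backward()                    # bf16 needs no GradScaler
        optimizer.step()
        return loss.item()

    return train_step


if __name__ == "__main__":
    step = build_train_step(torch.device("cpu"), bf16=True, ckpt=True)
    x, y = torch.randn(16, 32), torch.randn(16, 1)
    print([round(step(x, y), 4) for _ in range(3)])
```

Keeping each step behind its own flag makes it easy to turn one feature back off when a loss curve diverges from the fp32 baseline.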

Code

bf16 mixed precision: modern recipe·python
import torch
from torch.amp import autocast

# bf16: preferred on Ampere+ and Apple Silicon. No GradScaler needed.
# Assumes model (on CUDA), optimizer, criterion and loader are already defined.
for x, y in loader:
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
    optimizer.zero_grad()

    with autocast(device_type='cuda', dtype=torch.bfloat16):
        out = model(x)
        loss = criterion(out, y)

    loss.backward()
    optimizer.step()
fp16 mixed precision: older GPUs, GradScaler required·python
import torch
from torch.amp import autocast, GradScaler

scaler = GradScaler('cuda')

for x, y in loader:
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
    optimizer.zero_grad()

    with autocast(device_type='cuda', dtype=torch.float16):
        out = model(x)
        loss = criterion(out, y)

    scaler.scale(loss).backward()    # scale loss to avoid fp16 underflow
    scaler.step(optimizer)            # unscale + step (skip on inf/nan)
    scaler.update()                   # adjust scale dynamically
Gradient checkpointing: recompute instead of store·python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, num_layers, d_model=512, nhead=8):
        super().__init__()
        # nn.TransformerEncoderLayer stands in for your own TransformerBlock.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            if self.training:
                # Only the layer's input is saved; its internal activations
                # are recomputed during backward.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)  # no backward at inference, nothing to save
        return x

# For HuggingFace transformers, there's a one-liner:
# model.gradient_checkpointing_enable()
Combining bf16 + compile + checkpointing: production stack·python
import torch
from torch.amp import autocast

# MyTransformer, criterion, loader and num_epochs are assumed to be defined elsewhere.
model = MyTransformer().cuda()
if hasattr(model, 'gradient_checkpointing_enable'):   # e.g. HuggingFace models
    model.gradient_checkpointing_enable()
model = torch.compile(model, mode='default')

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
for epoch in range(num_epochs):
    for x, y in loader:
        x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type='cuda', dtype=torch.bfloat16):
            out = model(x)
            loss = criterion(out, y)

        loss.backward()                                   # bf16: no GradScaler
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


Exercise

Take a transformer-shaped model and measure peak GPU memory and per-step time in three configurations: (a) fp32 eager, (b) bf16 autocast, (c) bf16 + gradient checkpointing. Expected pattern: bf16 cuts memory by roughly 50% and is faster; checkpointing cuts memory by another 30-50% but is about 30% slower.
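One way to set up the measurement, as a rough sketch: the model is a hypothetical stack of stock nn.TransformerEncoderLayer modules, the loss is a dummy, and the default sizes are kept small so the script also runs on CPU (where peak GPU memory is reported as n/a). On a real GPU, raise seq, dim and depth until activations dominate. The warm-up step and torch.cuda.synchronize() keep one-time setup and asynchronous kernel launches out of the timing.

```python
import time
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def make_layers(dim, depth, nhead=4):
    # Hypothetical transformer-shaped model: a stack of stock encoder layers.
    return nn.ModuleList(
        nn.TransformerEncoderLayer(dim, nhead, dim_feedforward=4 * dim, batch_first=True)
        for _ in range(depth)
    )


def measure(bf16, use_ckpt, steps=3, batch=4, seq=256, dim=128, depth=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)
    layers = make_layers(dim, depth).to(device).train()
    opt = torch.optim.AdamW(layers.parameters(), lr=1e-4)
    x = torch.randn(batch, seq, dim, device=device)

    def step():
        opt.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=bf16):
            h = x
            for layer in layers:
                h = checkpoint(layer, h, use_reentrant=False) if use_ckpt else layer(h)
        h.float().pow(2).mean().backward()  # dummy loss, computed in fp32
        opt.step()

    step()  # warm-up so one-time setup costs are not timed
    if device.type == "cuda":
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    t0 = time.perf_counter()
    for _ in range(steps):
        step()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    ms_per_step = (time.perf_counter() - t0) / steps * 1e3
    peak_mib = torch.cuda.max_memory_allocated() / 2**20 if device.type == "cuda" else None
    return ms_per_step, peak_mib


configs = [("(a) fp32 eager", False, False),
           ("(b) bf16 autocast", True, False),
           ("(c) bf16 + checkpointing", True, True)]
for name, bf16, ckpt in configs:
    ms, mib = measure(bf16, ckpt)
    mem = "n/a (no CUDA)" if mib is None else f"{mib:.0f} MiB peak"
    print(f"{name:26} {ms:8.1f} ms/step  {mem}")
```

The reported peak includes parameters and AdamW state as well as activations, so at small sizes the checkpointing savings look smaller than they will on a full-size model.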
