C.W.K.
Stream
Lesson 04 of 06 · published

메모리 최적화

~22 min · bf16, gradient-checkpointing, flash-attention, memory

Level 0관찰자
0 XP0/43 lessons0/11 achievements
0/120 XP to next level120 XP to go0% complete

메모리 도구상자

Mixed precision (bf16/fp16)

32-bit 대신 16-bit 학습이 메모리 절반 + 현대 GPU에서 학습 가속.

bf16 vs fp16: bf16는 fp32랑 같은 범위(정밀도만 낮음)라 수치적으로 더 안정적. 하드웨어 지원하면 항상 bf16 선호(Ampere+: A100, RTX 3090+).

Gradient checkpointing

Compute를 메모리랑 trade — backward pass 중 activation 저장 대신 재계산. ~20% 느린 학습 비용에 VRAM 30~50% 감소.

Flash Attention 2

더 빠르고 메모리 효율적인 attention 구현. Attention 메모리를 O(n²)에서 O(n)으로 감소 + 학습 상당히 가속.

메모리 체크리스트 (영향 순)

  1. QLoRA(4-bit) 써 — 큰 모델에 단일 최대 승리.
  2. Gradient checkpointing 활성화.
  3. bf16 사용(가능하면 fp32 / fp16 X).
  4. Flash Attention 2 활성화.
  5. per_device batch 줄이고 gradient accumulation 늘려.
  6. max_seq_length 줄여.

Code

All four memory tricks combined·python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig
import torch

# 1. 4-bit QLoRA load
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# 2. Flash Attention 2 + bf16
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 3. Gradient checkpointing + bf16 mixed precision
args = SFTConfig(
    bf16=True,                                        # mixed precision
    gradient_checkpointing=True,                      # checkpointing
    gradient_checkpointing_kwargs={"use_reentrant": False},
    per_device_train_batch_size=2,                    # small per-device
    gradient_accumulation_steps=8,                    # effective batch 16
    max_seq_length=2048,                              # tune for memory
)

External links

Exercise

24GB GPU에서 Llama 3.1 8B QLoRA 학습 런 셋업. 네 메모리 트릭을 하나씩 적용하면서 각각 후 peak VRAM 기록: 베이스라인(fp16) → +bf16 → +grad checkpoint → +Flash Attention 2 → +QLoRA. 누적 VRAM 절약 매핑.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.