C.W.K.
Stream
Lesson 03 of 06 · published

FSDP — single GPU 에 안 들어가는 model train

~14 min · fsdp, shard, scale

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

model 이 한 GPU 에 너무 클 때

DDP 가 모든 GPU 에 model replicate. model 자체가 안 들어가면 깨짐. FullyShardedDataParallel (FSDP) 가 model parameter, gradient, AND optimizer state 를 GPU 들에 shard. 각 GPU 가 1/N 만 들고, forward 와 backward 위 필요한 slice gather 한 후 release. 24GB GPU cluster 에 70B-parameter model train 가능하게 만든 거.

Sharding strategy

  • FULL_SHARD — parameter, gradient, optimizer state shard. 가장 메모리 효율. 가장 큰 model 의 default 선택.
  • SHARD_GRAD_OP — gradient 와 optimizer state 만 shard; parameter 는 replicate. parameter all-gather 없어서 FULL_SHARD 보다 빠르지만 더 메모리.
  • NO_SHARD — DDP 등가. 같은 training script 공유하는 unit test 에 유용.

FSDP vs FSDP2

FSDP1 이 module wrap 하는 legacy API; FSDP2 (fully_shard() per-parameter approach) 가 새 in-development API, 깔끔한 interface 와 torch.compile 더 나은 composability. 2026 의 새 코드, bleeding-edge PyTorch 타겟이면 FSDP2 선호 — 근데 FSDP1 이 여전히 production-stable, 문서가 더 성숙.

같은 거 유지

training loop 안에선 FSDP 가 DDP 와 동일해 보여. model wrap 하고 평소대로 진행. 복잡도가 setup (nested module 의 auto-wrap policy, mixed-precision 설정) — loop 에 아님.

Code

FSDP1 — wrap-the-model API·python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

dist.init_process_group("nccl")
rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(rank)

model = MyHugeModel().cuda(rank)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=rank,
)

# The training loop is the SAME as DDP
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x.cuda(rank)), y.cuda(rank))
        loss.backward()
        optimizer.step()
Auto-wrap — nested module FSDP 가 처리·python
import torch
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import torch.nn as nn

# Wrap any nn.Module submodule with > 100M params automatically
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=100_000_000,
)

model = FSDP(
    MyHugeModel(),
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=rank,
)

# For Transformer-shaped models, prefer transformer_auto_wrap_policy
# which wraps each block — much better for FSDP performance
FSDP2 — modern per-parameter API·python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# FSDP2 is per-parameter and per-block, not whole-model
class TransformerStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(24)])
    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x

model = TransformerStack()

# Apply FSDP2 to each block — finer control than FSDP1's wrap policy
for layer in model.blocks:
    fully_shard(layer)
fully_shard(model)

# FSDP2 advantages:
#   - Cleaner interaction with torch.compile
#   - Per-parameter sharding (mix sharded / non-sharded freely)
#   - More composable with other parallelism strategies (TP, PP)

External links

Exercise

single GPU 에 편안히 fit 하는 model 잡기. multi-GPU box 에 FULL_SHARD 의 FSDP1 통과. per-GPU 메모리 사용이 single-GPU 사용의 대략 1/N 인지 검증 — 그게 sharding 실제로 작동. 이 scale 에 DDP vs speedup 기대 X; FSDP 가 model 이 다른 방법으로 안 들어갈 때 win.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.