C.W.K.
Stream
Lesson 04 of 08 · published

nn.Sequential 과 Composition Pattern

~10 min · sequential, container, composition

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

세 container, 하나의 pattern

PyTorch 의 세 핵심 composition module:

  • nn.Sequential(*modules) — module 을 순서대로 돌리고, 한 output 이 다음에 feed. custom forward() 불필요.
  • nn.ModuleList([modules]) — PyTorch 가 볼 수 있는 Python list. forward() 직접 짜고 원하는 대로 iterate.
  • nn.ModuleDict({name: module}) — 같은 idea, dict 모양. multi-head model 이나 branching architecture 에 유용.

'가능하면 Sequential, 어쩔 수 없으면 ModuleList' 룰이 코드의 ~80% 에 맞아. Sequential 이 깔끔하게 읽히고 print(model) 에 잘 나타남. ModuleList 는 branch, skip connection, 또는 dynamic depth (예: N 이 config 에서 오는 N-layer Transformer) 가 있는 거 위.

함정: 일반 Python collection 안 됨

plain listdict (nn 버전 아님) 에 layer 저장하면, model.parameters() 에 등장, .to(device) 로 이동, state_dict() 에 저장. 버그가 silent — model 은 train 하지만 PyTorch 가 볼 수 있는 layer 만. 항상 nn.Module* 버전 사용.

Code

Sequential — 쉬운 case·python
import torch.nn as nn

# Plain Sequential
mlp = nn.Sequential(
    nn.Linear(784, 256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 10),
)

# Named — useful for inspection and partial freezing
from collections import OrderedDict
mlp_named = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(784, 256)),
    ('act', nn.GELU()),
    ('drop', nn.Dropout(0.1)),
    ('fc2', nn.Linear(256, 10)),
]))

# Now you can refer by name
print(mlp_named.fc1)
ModuleList — 통제 필요할 때·python
import torch
import torch.nn as nn

class FlexibleMLP(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        # ModuleList — a list PyTorch can see
        self.layers = nn.ModuleList([
            nn.Linear(sizes[i], sizes[i+1])
            for i in range(len(sizes) - 1)
        ])
        self.act = nn.GELU()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:    # no activation on final layer
                x = self.act(x)
        return x

m = FlexibleMLP([784, 256, 128, 64, 10])
print(sum(p.numel() for p in m.parameters()))  # 234,506
silent-bug 예 — plain list·python
import torch
import torch.nn as nn

class BrokenModel(nn.Module):
    def __init__(self):
        super().__init__()
        # WRONG: regular list. PyTorch can't see these.
        self.layers = [nn.Linear(10, 10) for _ in range(3)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

m = BrokenModel()
print(list(m.parameters()))   # [] — empty!
m.to('cpu')                    # silently moves nothing
m(torch.randn(4, 10))          # works at first call (CPU only)

# Fix: use nn.ModuleList instead

External links

Exercise

두 번째 code block 의 FlexibleMLP 를 nn.ModuleDict 사용하게 변환, 'layer_0', 'layer_1' 등 이름. model.layers['layer_0'] 가 첫 Linear 반환하고 모든 parameter 가 여전히 model.parameters() 에 등장하는지 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.