Gradient 가 deep network 통과 흐를 때, 각 layer activation distribution 이 drift (Internal Covariate Shift, original 2015 framing). Normalization layer 가 각 layer input 이 stable mean 과 variance 갖게 강제, gradient 를 well-conditioned 유지하고 더 깊게, 더 높은 learning rate 로, 더 reliably train 가능.
BatchNorm
Batch dimension 따라 normalize: feature 마다 batch mean 빼고 batch standard deviation 으로 나눠. Feature 당 learned scale 과 shift 추가. CNN (특히 ResNet-style) 에 사용. 실패 모드: 작은 batch size (batch=1, batch=8) 가 unreliable statistics, train 과 eval 에서 다르게 행동 (eval 에 running statistics) — 유명한 버그 source.
LayerNorm
각 example 안에서 feature dimension 따라 normalize: example 마다 per-example mean 빼고 per-example standard deviation 으로 나눠. Transformer 에 사용 (variable-length sequence 에 batch dimension 이 awkward 라서). Train/eval 구분 없음 — 어느 쪽이든 같은 행동.
팁: 어느 거 쓸지 결정 못 하면: vision CNN (큰 batch, 고정 shape) 에 BatchNorm, transformer (작거나 variable batch shape) 에 LayerNorm. RMSNorm (stripped-down LayerNorm) 이 LLaMA 와 다른 현재 LLM 의 modern variant.
어디 둘지
흔한 패턴 두 개: Pre-norm (layer 전 norm): x = x + sublayer(LayerNorm(x)). 매우 deep transformer 에 더 stably train. Post-norm (layer 후 norm): x = LayerNorm(x + sublayer(x)). Original transformer convention. Modern transformer 가 default 로 pre-norm.
원칙: Normalization 이 2026 년에 deep network trainable 하게 만드는 거. 잘못된 종류 (transformer 에 BN, small CNN 에 LN) 골라도 보통 작동하지만 accuracy 또는 stability 의 10-20% 잃어. Type 을 architecture 에 매치.
Code
BatchNorm vs LayerNorm in code·python
import torch, torch.nn as nn
# CNN with BatchNorm
cnn = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.BatchNorm2d(64), # normalize across batch+spatial, per channel
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
# Transformer block with LayerNorm
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model), nn.GELU(),
nn.Linear(4 * d_model, d_model),
)
def forward(self, x):
# Pre-norm pattern
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x), need_weights=False)[0]
x = x + self.ffn(self.norm2(x))
return x