C.W.K.
Stream
Lesson 06 of 10 · published

Initialization: Xavier 와 He

~18 min · init, xavier, he

Level 0Curious
0 XP0/73 lessons0/11 achievements
0/120 XP to next level120 XP to go0% complete

왜 initialization 중요한가

Weight 를 너무 작게 initialize 하면 signal 이 network 통해 vanish, gradient 도 vanish 해서 돌아와. 너무 크면 signal explode, gradient explode. 올바른 scale 이 activation variance 를 layer 사이 대략 일정 유지, depth 에서 gradient chain 곱을 well-conditioned 유지.

Xavier (a.k.a. Glorot) initialization: std = sqrt(2 / (fan_in + fan_out)). Tanh/sigmoid activation 에 designed. He initialization: std = sqrt(2 / fan_in). ReLU (input 절반 죽이니까 보상하려 약간 더 큰 weight 필요) 에 designed. PyTorch nn.Linear default 가 He 의 uniform variant.

팁: PyTorch default 가 ReLU/GELU network 에 sensible. 첫 custom layer 또는 non-standard activation 쓰면 init 명시적 override — default 가 silently 깨짐.

Custom-init 패턴

Non-default architecture 에 single init_weights function 작성, model.apply(init_weights) 로 적용. 이게 module tree recursive walk 하고 모든 submodule 에 본인 function 호출. Transformer 코드의 standard practice.

특수 case

Output layer — 가끔 작은 scale 로 init (e.g., transformer LM head 에 0.02 std). Embedding — transformer 에 Normal(0, 0.02), word embedding 에 uniform. BatchNorm — γ=1, β=0 default, 일부 recipe 가 residual block 의 마지막 BN 에 γ=0 써서 identity 로 시작.

원칙: Initialization 은 작동할 때까지 tune 하고 그 다음 안 만지는 거 중 하나. ReLU/GELU/SiLU network 에 He, tanh/sigmoid 에 Xavier, 나머지엔 PyTorch default — 이유 있을 때까지.

Code

He init for a ReLU network·python
import torch.nn as nn
import math

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

model = MyModel()
model.apply(init_weights)

External links

Exercise

같은 MLP 를 (1) PyTorch default, (2) He init, (3) 의도적으로 나쁜 init (std=10) 으로 train. Loss curve 그려. 나쁜 init 은 train 거부하거나 diverge 해야 — 정확히 initialization 이 막아야 할 거.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.