C.W.K.
Stream
Lesson 10 of 12 · published

Output head — hidden state에서 logit으로

~8 min · output-head, weight-tying

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

최종 hidden state — 마지막 block 이후 residual stream — 의 shape는 (seq_len, d_model). 예측을 만들려면 그걸 vocab에 대한 logit shape (seq_len, vocab_size)로 바꿔야 해. 이게 output head.

기계적으로는 linear projection: logits = hidden @ W_lm.T, W_lm shape는 (vocab_size, d_model). 마지막 차원 softmax가 확률. 추론 시엔 보통 마지막 위치 logit만 필요(next-token 예측), 학습 시엔 모든 위치 logit을 병렬 계산.

Weight tying

많은 모델 — GPT-2, Llama, Mistral — 이 weight tying 사용 — input embedding 행렬과 output head가 같은 파라미터 공유. 수학적으로 W_lm = E.T. vocab × d_model 파라미터 절약(Llama 3 8B는 524M) + 만족스러운 대칭 — 비슷한 input embedding 갖는 토큰이 비슷한 output logit 프로필을 가져. 일부 더 큰 모델(GPT-3, GPT-4)은 weight tying 안 함 — 거대 스케일에선 파라미터 절감이 전체의 더 작은 비율이 되고, 분리하면 품질이 살짝 올라가.

Code

Tied output head·python
class TiedLM(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
    def forward(self, ids):
        x = self.embed(ids)             # (B, L, d_model)
        # ... transformer body ...
        x = final_norm(x)
        # Reuse embedding weights as the output projection
        logits = x @ self.embed.weight.T     # (B, L, vocab_size)
        return logits

External links

Exercise

weight tying 안 쓰는 작은 open-weight 모델 찾아(예: GPT-Neo 또는 일부 Mistral fine-tune). model.embed.weight랑 model.lm_head.weight 비교. 얼마나 비슷한가(행별 cosine similarity)? 고빈도 토큰 vs 희귀 토큰에서 비교 결과 달라지나?

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.