Causal masking — decoder-only 모델이 학습 중 답을 못 보게 하는 법
~12 min · causal-mask, decoder
Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete
next-token prediction으로 학습되는 decoder-only 모델에서, 위치 t는 위치 t+1의 토큰을 예측해야 해. 학습 중엔 모든 위치의 loss를 병렬 계산하고 싶은데, 모델이 forward pass 중 미래 토큰을 못 훔쳐 보게 해야 해. 해결책은 단순하고 핵심적이야 — causal mask.
구체적으론, softmax 전에 attention score 행렬의 상삼각에 -∞ 텐서를 더해. softmax 후 그 항목들이 정확히 0이 돼 — 해당 위치들이 완전히 무시. 위치 3은 0, 1, 2, 3에 attend할 수 있지만 4, 5, ...엔 못 함. 모델이 모든 위치를 한꺼번에 계산하는 병렬성을 얻으면서, 위치별로는 과거만 본 것처럼 행동.
causal_mask 구현. 앞의 multi-head attention에 끼워넣기. 작은 모델을 copy task로 학습. 그 다음 실수로 mask 제거하고 재학습. 모델 loss가 더 빨리 떨어지는 거(치팅 중) + 추론 시 생성 품질이 무너지는 거 관찰. 교훈 문서화.
Progress
Progress is local-only — sign in to sync across devices.