C.W.K.
Stream
Lesson 08 of 13 · published

Causal masking — decoder-only 모델이 학습 중 답을 못 보게 하는 법

~12 min · causal-mask, decoder

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

next-token prediction으로 학습되는 decoder-only 모델에서, 위치 t는 위치 t+1의 토큰을 예측해야 해. 학습 중엔 모든 위치의 loss를 병렬 계산하고 싶은데, 모델이 forward pass 중 미래 토큰을 못 훔쳐 보게 해야 해. 해결책은 단순하고 핵심적이야 — causal mask.

구체적으론, softmax 전에 attention score 행렬의 상삼각에 -∞ 텐서를 더해. softmax 후 그 항목들이 정확히 0이 돼 — 해당 위치들이 완전히 무시. 위치 3은 0, 1, 2, 3에 attend할 수 있지만 4, 5, ...엔 못 함. 모델이 모든 위치를 한꺼번에 계산하는 병렬성을 얻으면서, 위치별로는 과거만 본 것처럼 행동.

Code

Causal mask construction·python
import torch

def causal_mask(seq_len, device='cpu'):
    # Returns (seq_len, seq_len) bool tensor
    # True at positions (i, j) where j > i (future)
    return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

# Inside attention:
# scores = Q @ K.transpose(-2, -1) / d_k**0.5
# scores = scores.masked_fill(causal_mask(seq_len), float('-inf'))
# weights = torch.softmax(scores, dim=-1)
# After softmax, future positions have exactly 0 weight.
Inspect the resulting attention weights·python
seq_len = 6
mask = causal_mask(seq_len)
print(mask.int())
# tensor([[0, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]])
# After softmax with -inf substitution, position 0 attends only to itself,
# position 1 to {0, 1}, ... position 5 attends to all six.

External links

Exercise

causal_mask 구현. 앞의 multi-head attention에 끼워넣기. 작은 모델을 copy task로 학습. 그 다음 실수로 mask 제거하고 재학습. 모델 loss가 더 빨리 떨어지는 거(치팅 중) + 추론 시 생성 품질이 무너지는 거 관찰. 교훈 문서화.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.