Multi-Head Attention(MHA)은 attention 연산을 h개로 병렬 실행, 각각 낮은 차원 subspace에서, 결과 concat해서 d_model로 다시 projection. 직관: 모든 관계 패턴을 한꺼번에 잡아야 하는 attention pass 하나 대신 h개의 짧은 "specialist" 줘 — head마다 다른 패턴 학습 가능.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
head_i = Attention(Q · W_i^Q, K · W_i^K, V · W_i^V)
head_i = Attention(Q · W_i^Q, K · W_i^K, V · W_i^V)
구체적으로: h 선택(보통 8, 12, 32, 64). d_head = d_model / h(보통 64 또는 128). 각 head가 자기 W^Q_i, W^K_i, W^V_i 가짐, shape (d_model, d_head). h개의 (seq_len, d_head) 출력을 (seq_len, d_model) 텐서 하나로 concat. shape (d_model, d_model)인 최종 output projection W_O로 head 간 정보 섞음.
왜 이게 중요한가
학습된 모델에서 관찰되는 head들은 종종 특화돼:
- 일부 head는 직전 토큰에만 attend("position −1 head").
- 일부는 최근 context에서 같은 토큰에 attend("induction head").
- 일부는 문법적으로 관련된 토큰에 attend(subject ↔ verb).
- 일부는 거의 균일하게 attend해서 의미 평균화.