C.W.K.
Stream
Lesson 05 of 13 · published

Multi-head attention — 병렬 subspace들

~16 min · multi-head, mha

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

Multi-Head Attention(MHA)은 attention 연산을 h개로 병렬 실행, 각각 낮은 차원 subspace에서, 결과 concat해서 d_model로 다시 projection. 직관: 모든 관계 패턴을 한꺼번에 잡아야 하는 attention pass 하나 대신 h개의 짧은 "specialist" 줘 — head마다 다른 패턴 학습 가능.

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
head_i = Attention(Q · W_i^Q, K · W_i^K, V · W_i^V)

구체적으로: h 선택(보통 8, 12, 32, 64). d_head = d_model / h(보통 64 또는 128). 각 head가 자기 W^Q_i, W^K_i, W^V_i 가짐, shape (d_model, d_head). h개의 (seq_len, d_head) 출력을 (seq_len, d_model) 텐서 하나로 concat. shape (d_model, d_model)인 최종 output projection W_O로 head 간 정보 섞음.

왜 이게 중요한가

학습된 모델에서 관찰되는 head들은 종종 특화돼:

  • 일부 head는 직전 토큰에만 attend("position −1 head").
  • 일부는 최근 context에서 같은 토큰에 attend("induction head").
  • 일부는 문법적으로 관련된 토큰에 attend(subject ↔ verb).
  • 일부는 거의 균일하게 attend해서 의미 평균화.
이 emergent specialization은 Transformer에 대한 가장 강건한 경험적 발견 중 하나야.

Code

Multi-head attention forward pass·python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x, mask=None):
        B, L, d = x.shape
        H, dh = self.n_heads, self.d_head
        Q = self.W_q(x).view(B, L, H, dh).transpose(1, 2)   # (B, H, L, dh)
        K = self.W_k(x).view(B, L, H, dh).transpose(1, 2)
        V = self.W_v(x).view(B, L, H, dh).transpose(1, 2)
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
        out = out.transpose(1, 2).contiguous().view(B, L, d)
        return self.W_o(out)

External links

Exercise

위 MultiHeadAttention 구현. 작은 버전을 copy task로 학습. 학습 후 샘플 시퀀스에 대해 head별 attention map 플롯. 'previous-token' head, 'first-token' head, uniform head 찾을 수 있나? 이게 specialization에 대해 뭘 시사하나?

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.