C.W.K.
Stream
Lesson 02 of 13 · published

Q, K, V projection — 비대칭이 어디서 오나

~14 min · qkv, projections

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

각 토큰의 입력 벡터는 query, key, value 어느 역할에서든 똑같아. 비대칭은 같은 입력에서 Q, K, V를 만드는 독립적으로 학습되는 projection 행렬 셋에서 와.

Q = X · W_Q, K = X · W_K, V = X · W_V

세 행렬이 독립 학습. 학습 중 — 그리고 학습 중에 — 강조하는 feature가 갈라질 수 있고 실제로 갈라져. W_Q는 토큰을 "내가 뭘 찾고 싶은가?" subspace로, W_K는 "나는 검색당할 때 무엇인가?" subspace로, W_V는 "나는 다른 토큰의 표현에 무엇을 기여하나?" subspace로 투영하는 법을 학습.

흔한 shape

  • 표준 MHA: W_Q, W_K, W_V 각각 (d_model, d_model). projection 후 head로 reshape하면 각 head가 (seq_len, d_head) 텐서 봄, d_head = d_model / n_heads.
  • GQA / MQA: W_Q는 여전히 (d_model, d_model), W_K와 W_V는 더 작아 — n_heads 대신 n_kv_heads만큼의 행. (Track 4 GQA lesson에서 다룸.)
  • RoPE: 이 projection 다음, dot product 전에 Q와 K에 적용. V는 회전 안 함.

Code

Multi-head Q/K/V projection in PyTorch·python
class MultiHeadProjection(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x):
        B, L, _ = x.shape
        Q = self.W_q(x).view(B, L, self.n_heads, self.d_head)
        K = self.W_k(x).view(B, L, self.n_heads, self.d_head)
        V = self.W_v(x).view(B, L, self.n_heads, self.d_head)
        # Transpose to (B, n_heads, L, d_head) for attention
        return Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)

External links

Exercise

위 MultiHeadProjection 구현. W_Q, W_K, W_V 랜덤 초기화. (4, 16, 512) 입력 배치 전달. n_heads=8일 때 Q, K, V 각각 (4, 8, 16, 64)로 나오는지 검증. 이제 n_heads를 16으로 두 배 늘리면(d_model=512 유지) 메모리는?

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.