C.W.K.
Stream
Lesson 13 of 13 · published

GQA와 MQA — KV cache 줄이기

~12 min · gqa, mqa, memory

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

KV cache는 긴 컨텍스트 추론의 용. 표준 MHA는 query head마다 하나씩 K와 V를 저장, 컨텍스트 길이와 head 수 둘 다에 선형 스케일. Grouped-Query Attention(GQA)과 그 극단 사촌 Multi-Query Attention(MQA)이 여러 Q head 간 K, V를 공유해서 이걸 직접 공격.

shape들

변종Q headsKV headsKV cache 감소사용
MHAHH1× (베이스라인)BERT, GPT-3, GPT-2
GQAHG < HH/G ×Llama 3, Mistral, Gemma 3, Mixtral
MQAH1H ×초기 PaLM, Falcon

Llama 3.3 70B는 query head 64개, KV head 8개 — GQA 그룹 크기 8. query head 8개로 된 각 그룹이 K 하나, V 하나 projection 공유. KV cache가 8배 줄어. GQA가 MQA보다 KV 다양성이 더 많고 cache 절감의 대부분을 유지하니까 품질이 풀 MHA에 가까워.

왜 이게 이제 기본이 됐나

경험적으로 MHA에서 KV head 8개의 GQA로 가는 건 품질을 거의 안 잃고 긴 컨텍스트에서 거대한 메모리 절감. MQA는 더 공격적이지만 일부 task에서 품질 저하 — 그래서 분야는 GQA를 기본으로 정착. MHA 체크포인트를 GQA로 uptrain도 가능 — 각 그룹 안에서 KV head를 평균내고 원래 사전학습 컴퓨트의 약 5%로 학습 계속. Llama 2 70B-chat → GQA가 이렇게 만들어졌어.

Code

GQA in PyTorch (Llama-style)·python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_q_heads, n_kv_heads):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0
        self.n_q_heads, self.n_kv_heads = n_q_heads, n_kv_heads
        self.d_head = d_model // n_q_heads
        self.W_q = nn.Linear(d_model, n_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.repeat = n_q_heads // n_kv_heads
    def forward(self, x, mask=None):
        B, L, _ = x.shape
        Q = self.W_q(x).view(B, L, self.n_q_heads, self.d_head)
        K = self.W_k(x).view(B, L, self.n_kv_heads, self.d_head)
        V = self.W_v(x).view(B, L, self.n_kv_heads, self.d_head)
        # Repeat KV heads to match query heads
        K = K.repeat_interleave(self.repeat, dim=2)
        V = V.repeat_interleave(self.repeat, dim=2)
        Q, K, V = Q.transpose(1,2), K.transpose(1,2), V.transpose(1,2)
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
        return self.W_o(out.transpose(1,2).contiguous().view(B, L, -1))

External links

Exercise

GroupedQueryAttention 구현. 같은 shape의 MHA와 (B=4, n=8K) 입력에서 forward-pass 메모리 사용 비교. n=32K에서 반복. KV-cache 절감이 명확해야 해. 그 다음 각각으로 128K 컨텍스트가 얼마나 들지 추정.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.