C.W.K.
Stream
Lesson 10 of 13 · published

KV-cache — 생성이 매번 전부 재계산 안 하는 이유

~14 min · kv-cache, inference

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

autoregressive 생성 중 모델은 한 번에 토큰 하나씩 생산. 순진하게 하면 새 토큰마다 전체 시퀀스를 모든 layer에 통과 — 토큰당 O(n²) 연산, 너무 느려. KV-cache가 이걸 피해. 아이디어는 기계적이고 아름다워 — 과거 토큰의 K, V 벡터는 일단 계산되면 절대 안 변해, 캐싱.

동작 방식

스텝 t에서 새 토큰의 Q만 계산하면 돼. 새 토큰의 K와 V는 계산해서 캐시에 append. 새 토큰의 attention은 캐시된 K, V에 대한 단일 matmul — Q는 shape (1, d_head), K_cache는 (t, d_head), 그래서 score는 (1, t), 출력은 (1, d_v). 토큰당 연산이 O(n²)에서 O(n)으로.

비용: 메모리

캐시는 2 × n_layers × n_kv_heads × seq_len × d_head × bytes_per_param 바이트 저장. Llama 3.3(70B, 80 layer, 8 KV head, d_head=128) FP16에서 128K 컨텍스트 시 대략 32 GB. 거대 — 종종 모델 가중치만 하거나 더 커. 긴 컨텍스트 추론이 어려운 주요 이유이자 GQA가 존재하는 주요 이유.

Code

KV-cache pseudocode·python
class KVCache:
    def __init__(self, max_len, n_kv_heads, d_head, n_layers, dtype, device):
        # one cache per layer; here just one for clarity
        self.K = torch.zeros(max_len, n_kv_heads, d_head, dtype=dtype, device=device)
        self.V = torch.zeros(max_len, n_kv_heads, d_head, dtype=dtype, device=device)
        self.length = 0
    def append(self, k_new, v_new):
        # k_new, v_new: (1, n_kv_heads, d_head)
        self.K[self.length] = k_new
        self.V[self.length] = v_new
        self.length += 1
    def get(self):
        return self.K[:self.length], self.V[:self.length]

def step_with_cache(token_id, cache, model):
    q = model.q_proj(model.embed(token_id))    # (1, n_q_heads, d_head)
    k = model.k_proj(model.embed(token_id))    # (1, n_kv_heads, d_head)
    v = model.v_proj(model.embed(token_id))
    cache.append(k, v)
    K, V = cache.get()
    # attention against full cached K, V — this is O(t), not O(t^2)
    scores = q @ K.transpose(-2, -1) / d_head**0.5
    weights = torch.softmax(scores, dim=-1)
    return weights @ V
KV cache size estimator·python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_param=2):
    # 2 because we store both K and V
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param

# Llama 3.3 70B at 128K context, FP16
b = kv_cache_bytes(80, 8, 128, 128_000, bytes_per_param=2)
print(f"{b / 1e9:.1f} GB")     # ~32 GB

External links

Exercise

작은 GPT 스타일 모델 구현하고 생성 함수 두 개 작성 — 하나는 KV-cache 없이(매 스텝 전부 재계산), 하나는 KV-cache 사용. 시퀀스 길이 256, 512, 1024, 2048에서 둘 다 시간 측정. 토큰당 latency 플롯. 캐시 오버헤드가 가치 있어지는 교차점이 명확해야 해.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.