autoregressive 생성 중 모델은 한 번에 토큰 하나씩 생산. 순진하게 하면 새 토큰마다 전체 시퀀스를 모든 layer에 통과 — 토큰당 O(n²) 연산, 너무 느려. KV-cache가 이걸 피해. 아이디어는 기계적이고 아름다워 — 과거 토큰의 K, V 벡터는 일단 계산되면 절대 안 변해, 캐싱.
동작 방식
스텝 t에서 새 토큰의 Q만 계산하면 돼. 새 토큰의 K와 V는 계산해서 캐시에 append. 새 토큰의 attention은 캐시된 K, V에 대한 단일 matmul — Q는 shape (1, d_head), K_cache는 (t, d_head), 그래서 score는 (1, t), 출력은 (1, d_v). 토큰당 연산이 O(n²)에서 O(n)으로.
비용: 메모리
캐시는 2 × n_layers × n_kv_heads × seq_len × d_head × bytes_per_param 바이트 저장. Llama 3.3(70B, 80 layer, 8 KV head, d_head=128) FP16에서 128K 컨텍스트 시 대략 32 GB. 거대 — 종종 모델 가중치만 하거나 더 커. 긴 컨텍스트 추론이 어려운 주요 이유이자 GQA가 존재하는 주요 이유.
Code
KV-cache pseudocode·python
class KVCache:
def __init__(self, max_len, n_kv_heads, d_head, n_layers, dtype, device):
# one cache per layer; here just one for clarity
self.K = torch.zeros(max_len, n_kv_heads, d_head, dtype=dtype, device=device)
self.V = torch.zeros(max_len, n_kv_heads, d_head, dtype=dtype, device=device)
self.length = 0
def append(self, k_new, v_new):
# k_new, v_new: (1, n_kv_heads, d_head)
self.K[self.length] = k_new
self.V[self.length] = v_new
self.length += 1
def get(self):
return self.K[:self.length], self.V[:self.length]
def step_with_cache(token_id, cache, model):
q = model.q_proj(model.embed(token_id)) # (1, n_q_heads, d_head)
k = model.k_proj(model.embed(token_id)) # (1, n_kv_heads, d_head)
v = model.v_proj(model.embed(token_id))
cache.append(k, v)
K, V = cache.get()
# attention against full cached K, V — this is O(t), not O(t^2)
scores = q @ K.transpose(-2, -1) / d_head**0.5
weights = torch.softmax(scores, dim=-1)
return weights @ V
KV cache size estimator·python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_param=2):
# 2 because we store both K and V
return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param
# Llama 3.3 70B at 128K context, FP16
b = kv_cache_bytes(80, 8, 128, 128_000, bytes_per_param=2)
print(f"{b / 1e9:.1f} GB") # ~32 GB
작은 GPT 스타일 모델 구현하고 생성 함수 두 개 작성 — 하나는 KV-cache 없이(매 스텝 전부 재계산), 하나는 KV-cache 사용. 시퀀스 길이 256, 512, 1024, 2048에서 둘 다 시간 측정. 토큰당 latency 플롯. 캐시 오버헤드가 가치 있어지는 교차점이 명확해야 해.
Progress
Progress is local-only — sign in to sync across devices.