C.W.K.
Stream
Lesson 03 of 13 · published

Scaled dot-product attention — Q · K^T → softmax → · V

~14 min · scaled-dot-product, core

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

attention 연산의 핵심은 3단계. 외워.

  1. Score: scores = Q @ K.T / sqrt(d_k). 각 (i, j) 항목이 토큰 i가 토큰 j에 얼마나 attend해야 할지 측정. shape (n, n).
  2. Normalize: weights = softmax(scores, axis=-1). 각 행이 합 1인 확률 분포가 됨. 토큰 i의 attention이 모든 위치에 분산.
  3. Weighted sum: out = weights @ V. 각 행이 attention으로 가중된 value 벡터들의 convex combination. shape (n, d_v).

matmul 셋 + softmax. 메커니즘 전체가 이거야. multi-head, causal masking, RoPE, GQA, Flash Attention — 모던 attention의 모든 복잡성이 이 3단계 위에 얹힌 거지, 이걸 바꾸진 않아.

Code

Scaled dot-product attention from scratch·python
import torch, math

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (..., seq_len, d_k or d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (..., n, n)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)              # row-wise softmax
    return weights @ V                                    # (..., n, d_v)

# Equivalent to torch.nn.functional.scaled_dot_product_attention,
# minus the optimized kernels.
Verify against PyTorch's optimized kernel·python
import torch.nn.functional as F

Q = torch.randn(2, 4, 16, 64)
K = torch.randn(2, 4, 16, 64)
V = torch.randn(2, 4, 16, 64)

ref = F.scaled_dot_product_attention(Q, K, V)
mine = scaled_dot_product_attention(Q, K, V)

print(torch.allclose(ref, mine, atol=1e-5))   # True

External links

Exercise

scaled_dot_product_attention을 처음부터 구현. (B=2, n=8, d_k=32) 랜덤 텐서로 테스트. F.scaled_dot_product_attention이랑 비교. 그 다음 다시 안 돌리고, 1/sqrt(d_k) 스케일링을 제거하면 무슨 일이 생길지 예측. 테스트해. 모델 출력이 알아볼 수 있게 변했나?

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.