C.W.K.
Stream
Lesson 12 of 13 · published

Efficient attention — Flash Attention, sliding window, sparse

~16 min · flash-attention, sliding-window, sparse

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

대규모에서 attention을 실용화하기 위한 세 계열 기법이 등장. 정확한 attention 보존하느냐, 일부 정확도를 효율과 교환하느냐로 갈려.

Flash Attention (정확, 메모리 효율)

Flash Attention(Dao et al. 2022)은 표준 공식과 동일한 attention 출력 계산, 하지만 풀 (n × n) score 행렬을 GPU HBM에 절대 만들지 않아. 대신 연산을 타일링 — Q 블록 하나, K/V 블록 하나 로드, online 정규화로 부분 softmax 계산, 부분 결과 기록, 반복. 약간의 재계산을 메모리 대역폭의 거대한 감소와 교환. Flash Attention 2가 병렬성 개선; Flash Attention 3가 H100용 FP8 + warp specialization 지원, 약 740 TFLOPs/s 달성.

Sliding-window attention (근사, local)

각 토큰을 W 토큰의 local window에만 attend하도록 제한. O(n²) 대신 O(n × W). Mistral 7B, Mixtral, Gemma 3가 사용. 모델이 직접 장거리 attention을 잃지만 실용적 긴 컨텍스트 효율을 얻어. 일부 장거리 신호 보존을 위해 몇 개의 "global" attention layer 또는 "sink" 토큰과 결합되는 경우 多.

Sparse attention (근사, 구조적)

구조화된 sparse 패턴에 대해서만 attention 계산 — 예: 대각 띠 + 소수 global 토큰 + 랜덤 연결(Longformer, BigBird). Phi-3-small 사용. O(n × log n) 또는 O(n) 얻는 대가로 근사 출력.

Code

Flash Attention is just a kernel call·python
import torch
import torch.nn.functional as F

# In modern PyTorch (>=2.0), F.scaled_dot_product_attention
# automatically dispatches to Flash Attention 2 or 3
# when the inputs are eligible (CUDA, FP16/BF16, etc).

Q = torch.randn(2, 32, 8192, 128, device='cuda', dtype=torch.float16)
K = torch.randn(2, 32, 8192, 128, device='cuda', dtype=torch.float16)
V = torch.randn(2, 32, 8192, 128, device='cuda', dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True):
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
# Same output as a textbook implementation, but never instantiates
# the full (8192, 8192) score matrix per head.
Sliding-window mask·python
def sliding_window_mask(seq_len, window):
    # Allow each position to attend to itself and `window` previous positions
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j > i) | (j < i - window + 1)   # True where masked

External links

Exercise

F.scaled_dot_product_attention(Flash Attention 켜고)랑 교과서 attention 구현을 (B=2, n_heads=32, seq_len=4096, d_head=128)에서 벤치. wall-clock 시간이랑 peak GPU 메모리 둘 다 측정. seq_len=16384에서 반복. 속도 개선과 메모리 절감 문서화.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.