C.W.K.
Stream
Lesson 02 of 10 · published

병렬화의 벽: GPU가 RNN을 싫어한 이유

~18 min · parallelism, gpu, rnn

Level 0Token
0 XP0/94 lessons0/10 achievements
0/120 XP to next level120 XP to go0% complete

요즘 GPU는 극단적으로 병렬화된 칩이야. A100만 봐도 CUDA core 6,912개 + 텐서 코어 432개. H100은 거기서 더 나갔지. "한 클럭에 수천 번의 곱셈-덧셈을 동시에" — 이게 GPU의 인격이야. 그런데 RNN은 시퀀스 축 방향으로는 본질적으로 직렬이라, 코어가 아무리 많아도 도와줄 수가 없어.

구현 문제가 아니라 의존 관계 문제야. h[t]를 구하려면 h[t-1]이 필요하고, 그건 h[t-2]가 필요하고… 결국 h[0]까지 거슬러 올라가는 길이 n짜리 사슬이 있어. 이걸 한 칸씩 걷는 거 외엔 방법이 없어.

self-attention은 이 사슬을 그냥 우회해. 모든 위치가 같은 입력 임베딩 집합으로부터 계산되니까, n개의 위치별 업데이트가 단일 행렬곱 안에서 한꺼번에 일어나. 시퀀스 길이를 두 배로 늘려도 wall-clock 시간이 두 배가 아니야 — matmul 크기만 커지는데, GPU가 가장 좋아하는 게 큰 matmul이거든.

2017년 논문의 영수증

원조 트랜스포머 논문 보면, base 모델은 P100 8장으로 12시간 만에 학습 끝(총 3.3 × 10¹⁸ FLOPs), big 모델은 3.5일. 비슷한 시기 deep LSTM 기반 번역 시스템들은 같은 품질 내려고 한 자릿수 더 걸렸어. FLOP 개수만이 아니라, 아키텍처 모양이 하드웨어 활용률을 결정한 거야.

Code

Sequential vs parallel — the wall-clock difference·python
# RNN: cannot start step t until step t-1 is done
for t in range(seq_len):
    h[t] = f(h[t-1], x[t])      # sequential dependency

# Self-attention: all positions in one matmul
# Q, K, V are (seq_len, d_k); attention is one matrix product
scores = Q @ K.T / sqrt(d_k)    # (seq_len, seq_len)
weights = softmax(scores)
out = weights @ V                # (seq_len, d_v)
# GPU lays out all seq_len positions across thousands of cores
Quick benchmark sketch (PyTorch)·python
import torch, time
n, d = 1024, 512
x = torch.randn(n, d, device='cuda')

# RNN-style sequential update
h = torch.zeros(d, device='cuda')
W = torch.randn(d, d, device='cuda')
torch.cuda.synchronize(); t0 = time.time()
for t in range(n):
    h = torch.tanh(W @ h + x[t])
torch.cuda.synchronize(); print('RNN-like:', time.time() - t0)

# Single matmul (attention-style)
W2 = torch.randn(d, d, device='cuda')
torch.cuda.synchronize(); t0 = time.time()
out = x @ W2
torch.cuda.synchronize(); print('Matmul:  ', time.time() - t0)
# Even ignoring softmax/divisions, the gap is enormous.

External links

Exercise

GPU 한 장에서 같은 작업을 두 번 프로파일링해 봐 — 2-layer LSTM 하나, 작은 Transformer encoder 하나. 둘 다 d_model=128, 가능한 곳엔 8 attention head 적용. 시퀀스 길이 256 / 1024 / 4096에서 tokens-per-second 측정해서 로그 스케일로 그래프 그려. 어디서 각 아키텍처가 더 이상 안 늘어나는지, 그 이유가 뭔지 설명해 봐.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.