C.W.K.
Stream
Lesson 01 of 04 · published

왜 AI는 두 primitive로 무너지는가

~12 min · linalg, dot-product, gemm, primitive

Level 0Beginner
0 XP0/38 lessons0/12 achievements
0/100 XP to next level100 XP to go0% complete

컴파일된 AI 그래프 보면 op 두 개가 보여

Triton 커널, XLA HLO dump, PyTorch torch.compile 출력 열어서 ops 훑어. Convolution, attention, MLP, RNN, GNN — 학습 layer 전부 결국 두 primitive 조합으로 분해돼:

Primitive수학실제로 보이는 곳
Dot Producty = Σ xᵢ · wᵢAttention score, vector search의 cosine similarity, neuron 하나의 출력
Matrix Multiply (GEMM)C = α·A·B + β·CDense / Linear layer, Q-K-V projection, MLP block, batched embedding lookup

Convolution은 im2col이나 implicit-gemm 거쳐서 GEMM이 돼. Softmax는 element-wise 수학 + reduction. LayerNorm은 reduction 둘 + element-wise scale. '진짜' compute, FLOP 90% 들어가는 부분은 GEMM이야.

그래서 하드웨어 vendor랑 라이브러리 저자들이 모양 하나에 수십 년을 쏟아부어 — 빠른 GEMM. NVIDIA Tensor Core, Apple matrix coprocessor, M1+ AMX, AMD Matrix Core. 다 GPU에 박힌 더 큰, 더 빠른 GEMM 가속기야.

이거 내재화하면 모델 디자인이 달라져. '헤드 더 추가할까?' 가 'GEMM 모양이 어떻게 바뀌지?' 가 돼. Padding, batching, hidden-size 선택 — 다 GEMM 라이브러리가 깔끔하게 vectorize할 수 있는 모양 먹이고 있느냐로 무너져.

Code

PyTorch에서 두 primitive — 모든 모델의 재료·python
import torch

# Dot product — neuron 한 개 출력, attention score
x = torch.randn(1024)
w = torch.randn(1024)
y = torch.dot(x, w)         # scalar

# Matrix multiply — 모든 Linear layer, Q/K/V projection, MLP
A = torch.randn(64, 1024)   # 시퀀스 64 × hidden 1024
B = torch.randn(1024, 4096) # weight 행렬, hidden → ffn
C = A @ B                   # (64, 4096)

# torch.compile이 vendor 최적화 fused 커널로 lower.
# 보닛 안: NVIDIA cuBLAS, Apple MPS, Intel oneDNN.
MLX (Apple)에서 같은 아이디어 — 같은 두 primitive, 다른 runtime·python
import mlx.core as mx

x = mx.random.normal((1024,))
w = mx.random.normal((1024,))
y = mx.matmul(x, w)         # 같은 matmul 커널로 scalar

A = mx.random.normal((64, 1024))
B = mx.random.normal((1024, 4096))
C = A @ B                   # MLX는 MPSMatrixMultiplication에 dispatch
mx.eval(C)                  # MLX는 lazy; eval이 실행 강제.

External links

Exercise

PyTorch 스크립트 아무거나에서 torch._dynamo.config.verbose = True 설정하고 작은 모델 (2-layer MLP면 충분)에 torch.compile 돌려. trace 출력에서 'matmul' / 'mm' / 'addmm' 단어 빈도 훑어. 다른 op 이름이랑 비교해서 얼마나 자주 보이는지 — 그래서 GEMM이 최적화 예산 다 가져가는 거.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.