같은 C-level 호출, 세 Python 맛
대부분 시간엔 CUDA / Metal 안 쓰고 — 어딘가 밑에서 CUDA / Metal 호출 trigger하는 Python 써. 이 레슨이 Python 표현이 어느 BLAS routine으로 가는지 매핑.
| Python 표현 | BLAS 호출 | 어디 |
|---|---|---|
np.dot(x, y) | sdot / ddot (BLAS-1) | CPU의 OpenBLAS 또는 MKL |
np.dot(A, x) | sgemv / dgemv (BLAS-2) | CPU의 OpenBLAS 또는 MKL |
np.dot(A, B) 또는 A @ B | sgemm / dgemm (BLAS-3) | CPU의 OpenBLAS 또는 MKL |
CUDA tensor의 torch.matmul(A, B) | cublasSgemm / cublasGemmEx | cuBLAS |
MPS tensor의 torch.matmul(A, B) | MPSMatrixMultiplication encode/dispatch | MPS |
mx.matmul(A, B); mx.eval(C) | MLX runtime 통한 MPSMatrixMultiplication | MPS |
Python이 왜 안 느리게 만드나
Python 인터프리터가 BLAS 호출당 정확히 한 번 관여: 표현 파싱, C 확장에 dispatch. 실제 GEMM은 전부 C/CUDA/Metal 영역에서 돔. GPU 시간 1.5 ms 걸리는 4096³ GEMM에 Python 오버헤드는 microsecond — ~0.1%. 그래서 PyTorch의 A @ B가 손으로 쓴 C++ wrapper랑 같은 throughput.
뭐가 느리게 만드냐: 작은 호출 많이. 각 A @ B가 따로 dispatch. 큰 거 1번 대신 작은 matmul 1000번 하면 Python 오버헤드 millisecond로 누적. torch.compile, jax.jit, MLX의 lazy evaluation 다 작은 호출을 큰 거로 fuse하려고 존재.