JAX 의 빠름은 운으로 오는 게 아니야. XLA (Accelerated Linear Algebra) — Google 이 만든 도메인 특화 compiler 가 핵심이야. JAX 함수를 trace 해서 XLA HLO IR 로 변환하고, 거기서 fusion / layout 최적화 / 디바이스 별 코드 생성을 다 해.
- NumPy: CPU 만. C/Fortran 백엔드. GPU 못 씀.
- jax.numpy: XLA 가 결정. CPU 면 LLVM 최적화, GPU 면 cuDNN/cuBLAS, TPU 면 TPU instruction.
중요한 게 fusion. NumPy 에서 (x*2 + 1) ** 0.5 쓰면 — 메모리 중간 결과 3 번 만들어졌다 사라져. XLA 는 single fused kernel 로 합쳐.
import jax
import jax.numpy as jnp
import time
x_jax = jnp.zeros((2048, 2048))
@jax.jit
def f(x):
return jnp.sin(x @ x) + jnp.cos(x @ x)
f(x_jax).block_until_ready() # warm up
t = time.time()
for _ in range(10):
f(x_jax).block_until_ready()
print(f"JAX: {time.time()-t:.3f}s")
🔬 mental model
JAX 코드를 짤 땐 — "내가 짜는 건 NumPy 코드, 그런데 실행은 XLA 가 적절한 hardware 에서 한다" 라고 생각해. jax.devices() 로 어떤 device 가 잡혔는지 확인.