JIT 의 효과를 측정한다 — 단순한 일 같은데 함정이 많아. 가장 큰 함정: async dispatch.
JAX 호출은 default 로 async — array 를 반환하지만 실제 계산은 백그라운드에서 진행. time.time() 으로 측정하면 — Python 이 ndarray 받은 시각이지, 계산이 끝난 시각이 아님.
import jax
import jax.numpy as jnp
import time
@jax.jit
def f(x):
return jnp.sum(x ** 2 + jnp.sin(x))
x = jnp.arange(1_000_000.)
# 잘못된 측정
t = time.time()
for _ in range(100):
y = f(x) # async 반환
print(f"잘못된 측정: {time.time()-t:.4f}s")
# 올바른 측정 — block_until_ready
t = time.time()
for _ in range(100):
y = f(x).block_until_ready()
print(f"올바른 측정: {time.time()-t:.4f}s")
차이가 10 배 이상 나기도 함. 항상 .block_until_ready() 또는 jax.block_until_ready(...).
compile 시간 vs run 시간 분리:
# Warm-up: compile + cache
y = f(x).block_until_ready()
# 이제 run 시간만
t = time.time()
for _ in range(1000):
y = f(x).block_until_ready()
elapsed = time.time() - t
print(f"호출 당: {elapsed/1000*1000:.3f}ms")
전형적 결과 (1M 원소, CPU):
- NumPy: ~ 4 ms
- JAX eager (jit 없음): ~ 6 ms — 약간 느림 (overhead)
- JAX jit, 첫 호출: ~ 200 ms (compile 포함)
- JAX jit, 후속 호출: ~ 0.5 ms — 8x 빠름
GPU 면 더 극적: cold ~ 500ms, warm ~ 50us — 80x 이상.
벤치마크 스크립트:
def benchmark(f, *args, n_warmup=3, n_runs=100, label=""):
'''함수의 평균 호출 시간 측정. async dispatch 처리.'''
# warm-up (compile)
for _ in range(n_warmup):
out = f(*args)
if hasattr(out, "block_until_ready"):
out.block_until_ready()
# measure
t = time.time()
for _ in range(n_runs):
out = f(*args)
if hasattr(out, "block_until_ready"):
out.block_until_ready()
elapsed = (time.time() - t) / n_runs * 1000
print(f"{label}: {elapsed:.3f}ms / call")
import numpy as np
x_np = np.random.randn(1_000_000).astype(np.float32)
x_jax = jnp.array(x_np)
benchmark(lambda x: np.sum(x**2 + np.sin(x)), x_np, label="NumPy")
benchmark(jax.jit(lambda x: jnp.sum(x**2 + jnp.sin(x))), x_jax, label="JAX jit")
⚠️ 벤치마크의 함정
(1) async dispatch — block_until_ready 빠뜨리면 가짜 빠름. (2) cold vs warm — 첫 호출은 compile 포함. (3) GPU 의 lazy launch — 실제 GPU 작업은 host 한 발짝 늦게 시작. (4) print 가 측정 루프 안에 있으면 — print 가 동기화 점이 됨. (5) 작은 array (< 1000 원소) 는 overhead 가 지배적이라 JAX 가 NumPy 보다 느릴 수도 있음.
실용 결론: 큰 array, 무거운 연산, 여러 번 호출 — 셋 다 만족하면 jit 효과가 극적. 한 번 호출 / 작은 데이터 / 단순 연산이면 jit 안 붙여도 됨. 그러나 학습 loop 의 step 함수는 거의 항상 jit 한다.