jax.jit 이 실제로 무얼 하는지 — 한 번 보면 모든 게 명확해져. 3 단계로 봐:
1. Trace — 첫 호출 때 함수를 abstract 한 Tracer object 로 한 번 실행. 어떤 연산이 어떤 순서로 일어나는지 IR (XLA HLO) 로 기록.
2. Compile — IR 을 XLA compiler 로 넘김. 가속기 별 native code 생성.
3. Cache — compile 결과를 (input shapes, dtypes, static args) 키로 캐시. 다음 호출 때 같은 키면 재사용.
import jax
import jax.numpy as jnp
import time
@jax.jit
def f(x):
return jnp.sum(x ** 2 + jnp.sin(x))
x = jnp.arange(1_000_000.0)
# 첫 호출 — trace + compile + run
t = time.time()
y = f(x).block_until_ready()
print(f"첫 호출: {time.time()-t:.3f}s") # ~ 0.3s
# 두 번째 — cache hit, 빠름
t = time.time()
y = f(x).block_until_ready()
print(f"두 번째: {time.time()-t:.3f}s") # ~ 0.001s
# 다른 shape — 새 trace + compile
y = f(jnp.arange(2_000_000.0)).block_until_ready()
cache key 의 핵심:
- Shape:
(1000,)와(2000,)는 다른 키 → recompile. - Dtype: float32 와 float64 는 다른 키.
- Static args:
static_argnames로 표시한 인자의 값 변화도 recompile. - Device: 같은 함수도 CPU 와 GPU 는 별도 cache.
실험으로 보면:
def trace_count_demo():
n_traces = 0
@jax.jit
def f(x):
nonlocal n_traces
n_traces += 1 # trace 시점에만 +1
return x * 2
f(jnp.arange(10.)) # trace + compile + run
f(jnp.arange(10.)) # cache hit
f(jnp.arange(20.)) # 다른 shape → recompile
f(jnp.arange(20., dtype=jnp.float64)) # 다른 dtype → recompile
print(f"trace 횟수: {n_traces}")
⚡ "compile 한 번, 호출 여러 번" 모델
JAX 의 성능 모델 핵심 — 함수를 처음 한 번 compile 하고 그 후 수 천 번 빠르게 부른다. 학습 loop 라면 — train_step 함수가 첫 step 에서 compile 되고 나머지 1 만 step 은 cached 호출. compile 비용 (한 번 1 초) 은 학습 시간 (1 시간) 에 비해 무시할 만함. 그래서 jit 은 거의 항상 이득.
중요한 함정 — shape 이 호출마다 바뀌면 매번 recompile. 학습 마지막 batch 가 미세하게 작은 거 (drop_last=False) — 그것 때문에 매 epoch 마다 compile 한 번 더 할 수도 있어. 해결: pad 해서 shape 동일하게 유지.
cache 비우려면 jax.clear_caches() — 보통 안 써도 됨. 메모리 압박 있을 때만.