C.W.K.
Stream
Lesson 01 of 05 · published

jit 의 일: Tracing, Compiling, Caching

~8 min · jit, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.jit 이 실제로 무얼 하는지 — 한 번 보면 모든 게 명확해져. 3 단계로 봐:

1. Trace — 첫 호출 때 함수를 abstract 한 Tracer object 로 한 번 실행. 어떤 연산이 어떤 순서로 일어나는지 IR (XLA HLO) 로 기록.

2. Compile — IR 을 XLA compiler 로 넘김. 가속기 별 native code 생성.

3. Cache — compile 결과를 (input shapes, dtypes, static args) 키로 캐시. 다음 호출 때 같은 키면 재사용.

import jax
import jax.numpy as jnp
import time

@jax.jit
def f(x):
    return jnp.sum(x ** 2 + jnp.sin(x))

x = jnp.arange(1_000_000.0)

# 첫 호출 — trace + compile + run
t = time.time()
y = f(x).block_until_ready()
print(f"첫 호출: {time.time()-t:.3f}s")  # ~ 0.3s

# 두 번째 — cache hit, 빠름
t = time.time()
y = f(x).block_until_ready()
print(f"두 번째: {time.time()-t:.3f}s")  # ~ 0.001s

# 다른 shape — 새 trace + compile
y = f(jnp.arange(2_000_000.0)).block_until_ready()

cache key 의 핵심:

  • Shape: (1000,)(2000,) 는 다른 키 → recompile.
  • Dtype: float32 와 float64 는 다른 키.
  • Static args: static_argnames 로 표시한 인자의 값 변화도 recompile.
  • Device: 같은 함수도 CPU 와 GPU 는 별도 cache.

실험으로 보면:

def trace_count_demo():
    n_traces = 0
    @jax.jit
    def f(x):
        nonlocal n_traces
        n_traces += 1   # trace 시점에만 +1
        return x * 2

    f(jnp.arange(10.))   # trace + compile + run
    f(jnp.arange(10.))   # cache hit
    f(jnp.arange(20.))   # 다른 shape → recompile
    f(jnp.arange(20., dtype=jnp.float64))  # 다른 dtype → recompile
    print(f"trace 횟수: {n_traces}")

⚡ "compile 한 번, 호출 여러 번" 모델

JAX 의 성능 모델 핵심 — 함수를 처음 한 번 compile 하고 그 후 수 천 번 빠르게 부른다. 학습 loop 라면 — train_step 함수가 첫 step 에서 compile 되고 나머지 1 만 step 은 cached 호출. compile 비용 (한 번 1 초) 은 학습 시간 (1 시간) 에 비해 무시할 만함. 그래서 jit 은 거의 항상 이득.

중요한 함정 — shape 이 호출마다 바뀌면 매번 recompile. 학습 마지막 batch 가 미세하게 작은 거 (drop_last=False) — 그것 때문에 매 epoch 마다 compile 한 번 더 할 수도 있어. 해결: pad 해서 shape 동일하게 유지.

cache 비우려면 jax.clear_caches() — 보통 안 써도 됨. 메모리 압박 있을 때만.

Code

import jax
import jax.numpy as jnp

def slow_fn(x):
    """Each operation launches a separate kernel."""
    y = jnp.sin(x)
    z = jnp.cos(x)
    return jnp.sum(y * z + y ** 2)

# Wrap with jit
fast_fn = jax.jit(slow_fn)

# Or use as a decorator
@jax.jit
def fast_fn_v2(x):
    y = jnp.sin(x)
    z = jnp.cos(x)
    return jnp.sum(y * z + y ** 2)

x = jnp.ones(10000)
result = fast_fn(x)  # First call: trace + compile + execute
result = fast_fn(x)  # Second call: execute cached compiled code (much faster)
import jax
import jax.numpy as jnp

@jax.jit
def add_one(x):
    print("TRACING!")  # This print helps us see when tracing happens
    return x + 1

# Call 1: traces + compiles + executes
add_one(jnp.array([1.0, 2.0]))  # Prints "TRACING!"

# Call 2: same shape → uses cache, no retracing
add_one(jnp.array([3.0, 4.0]))  # No print — cached!

# Call 3: different shape → must retrace
add_one(jnp.array([1.0, 2.0, 3.0]))  # Prints "TRACING!" — new shape

External links

Exercise

같은 matrix multiply 3 가지로 측정: pure jnp (jit 없음), jit cold (첫 호출), jit warm (이후). matrix size 64, 256, 1024, 2048 의 세 숫자 plot 또는 print. 곡선 읽기 — jit 의 payoff 가 공짜 아냐.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.