C.W.K.
Stream
Lesson 05 of 05 · published

JIT 벤치마킹: 실제 속도 측정

~9 min · jit, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JIT 의 효과를 측정한다 — 단순한 일 같은데 함정이 많아. 가장 큰 함정: async dispatch.

JAX 호출은 default 로 async — array 를 반환하지만 실제 계산은 백그라운드에서 진행. time.time() 으로 측정하면 — Python 이 ndarray 받은 시각이지, 계산이 끝난 시각이 아님.

import jax
import jax.numpy as jnp
import time

@jax.jit
def f(x):
    return jnp.sum(x ** 2 + jnp.sin(x))

x = jnp.arange(1_000_000.)

# 잘못된 측정
t = time.time()
for _ in range(100):
    y = f(x)   # async 반환
print(f"잘못된 측정: {time.time()-t:.4f}s")

# 올바른 측정 — block_until_ready
t = time.time()
for _ in range(100):
    y = f(x).block_until_ready()
print(f"올바른 측정: {time.time()-t:.4f}s")

차이가 10 배 이상 나기도 함. 항상 .block_until_ready() 또는 jax.block_until_ready(...).

compile 시간 vs run 시간 분리:

# Warm-up: compile + cache
y = f(x).block_until_ready()

# 이제 run 시간만
t = time.time()
for _ in range(1000):
    y = f(x).block_until_ready()
elapsed = time.time() - t
print(f"호출 당: {elapsed/1000*1000:.3f}ms")

전형적 결과 (1M 원소, CPU):

  • NumPy: ~ 4 ms
  • JAX eager (jit 없음): ~ 6 ms — 약간 느림 (overhead)
  • JAX jit, 첫 호출: ~ 200 ms (compile 포함)
  • JAX jit, 후속 호출: ~ 0.5 ms — 8x 빠름

GPU 면 더 극적: cold ~ 500ms, warm ~ 50us — 80x 이상.

벤치마크 스크립트:

def benchmark(f, *args, n_warmup=3, n_runs=100, label=""):
    '''함수의 평균 호출 시간 측정. async dispatch 처리.'''
    # warm-up (compile)
    for _ in range(n_warmup):
        out = f(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()

    # measure
    t = time.time()
    for _ in range(n_runs):
        out = f(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
    elapsed = (time.time() - t) / n_runs * 1000
    print(f"{label}: {elapsed:.3f}ms / call")

import numpy as np
x_np = np.random.randn(1_000_000).astype(np.float32)
x_jax = jnp.array(x_np)

benchmark(lambda x: np.sum(x**2 + np.sin(x)), x_np, label="NumPy")
benchmark(jax.jit(lambda x: jnp.sum(x**2 + jnp.sin(x))), x_jax, label="JAX jit")

⚠️ 벤치마크의 함정

(1) async dispatch — block_until_ready 빠뜨리면 가짜 빠름. (2) cold vs warm — 첫 호출은 compile 포함. (3) GPU 의 lazy launch — 실제 GPU 작업은 host 한 발짝 늦게 시작. (4) print 가 측정 루프 안에 있으면 — print 가 동기화 점이 됨. (5) 작은 array (< 1000 원소) 는 overhead 가 지배적이라 JAX 가 NumPy 보다 느릴 수도 있음.

실용 결론: 큰 array, 무거운 연산, 여러 번 호출 — 셋 다 만족하면 jit 효과가 극적. 한 번 호출 / 작은 데이터 / 단순 연산이면 jit 안 붙여도 됨. 그러나 학습 loop 의 step 함수는 거의 항상 jit 한다.

Code

import jax
import jax.numpy as jnp
import time

def matrix_computation(A, B):
    """A non-trivial computation to benchmark."""
    C = A @ B
    D = jnp.sin(C) + jnp.cos(C)
    E = D @ D.T
    return jnp.sum(jnp.log(jnp.abs(E) + 1e-7))

# Create test data
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (1000, 1000))
B = jax.random.normal(jax.random.PRNGKey(1), (1000, 1000))

# Benchmark WITHOUT jit
# Important: block_until_ready() ensures the computation is actually done
def time_fn(fn, *args, n_runs=10):
    # Warm up
    result = fn(*args)
    result.block_until_ready()

    start = time.perf_counter()
    for _ in range(n_runs):
        result = fn(*args)
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    return elapsed / n_runs

eager_time = time_fn(matrix_computation, A, B)
print(f"Eager: {eager_time * 1000:.2f} ms")

# Benchmark WITH jit
jitted_computation = jax.jit(matrix_computation)

# Warm up (trigger compilation)
jitted_computation(A, B).block_until_ready()

jit_time = time_fn(jitted_computation, A, B)
print(f"JIT:   {jit_time * 1000:.2f} ms")
print(f"Speedup: {eager_time / jit_time:.1f}x")
import jax
import jax.numpy as jnp
import time

@jax.jit
def big_fn(x):
    for _ in range(20):
        x = jnp.sin(x) + jnp.cos(x)
    return jnp.sum(x)

x = jnp.ones((2000, 2000))

# First call includes compilation
start = time.perf_counter()
result = big_fn(x)
result.block_until_ready()
first_call = time.perf_counter() - start

# Second call uses cache
start = time.perf_counter()
result = big_fn(x)
result.block_until_ready()
second_call = time.perf_counter() - start

print(f"First call (compile + run): {first_call * 1000:.1f} ms")
print(f"Second call (run only):     {second_call * 1000:.1f} ms")
print(f"Compilation overhead:       ~{(first_call - second_call) * 1000:.1f} ms")

External links

Exercise

같은 forward pass 를 block_until_ready() 있을 때와 없을 때 측정. apparent speedup 계산. apparent — async dispatch 가 eager 숫자를 부풀려서. 적절한 blocking 으로 재측정. 정확한 숫자가 충격적 — 그게 진짜.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.