jit 의 한계 — Python 의 if/for 같은 control flow 를 trace 시점의 abstract value 로는 결정할 수 없어. 두 가지 도구로 우회:
1. static_argnames / static_argnums — 인자를 "trace 의 일부가 아닌, cache key 의 일부" 로 표시.
from functools import partial
@partial(jax.jit, static_argnames=("mode",))
def f(x, mode):
if mode == "train": # OK — mode 는 concrete (static)
return x * 2
else:
return x ** 2
f(jnp.array([1., 2.]), mode="train") # compile 1
f(jnp.array([1., 2.]), mode="eval") # compile 2
f(jnp.array([3., 4.]), mode="train") # cache hit (compile 1)
주의 — static 인자의 값이 바뀔 때마다 새 compile. 너무 다양한 값으로 부르면 cache 폭발 → 메모리 누수처럼 보일 수 있음.
2. jax.lax.cond / jax.lax.switch — runtime 분기를 IR 안에서 표현.
import jax.lax as lax
@jax.jit
def f(x):
return lax.cond(
x > 0, # predicate (traced 가능)
lambda x: x * 2, # true branch
lambda x: x ** 2, # false branch
x, # operand
)
f(jnp.array(3.0)) # 6
f(jnp.array(-2.0)) # 4
둘 다 single jit 으로, runtime 분기. switch 는 N way:
def f(idx, x):
return lax.switch(
idx,
[lambda x: x,
lambda x: x ** 2,
lambda x: jnp.sin(x)],
x,
)
Loop 는 jax.lax.scan:
# Python for-loop — trace 시점에 unroll. 길면 compile 오래 걸림.
@jax.jit
def f_unroll(x):
for _ in range(100):
x = jnp.sin(x)
return x
# scan — IR 의 loop 로 compile. 길이 무관 한 번만 compile.
@jax.jit
def f_scan(x):
def step(carry, _):
return jnp.sin(carry), None
final, _ = jax.lax.scan(step, x, jnp.zeros(100))
return final
scan 의 의미: step(carry, x_t) → (new_carry, y_t). RNN, training loop, 누적 등 거의 모든 sequential 연산에 사용.
fori_loop / while_loop 도 있음:
# 정해진 횟수 — fori_loop
def body(i, x):
return x + jnp.sin(x * i)
result = jax.lax.fori_loop(0, 100, body, init_val)
# 조건 종료 — while_loop
def cond(state): return state["i"] < 100 and state["err"] > 1e-6
def body(state): ...
final = jax.lax.while_loop(cond, body, initial_state)
⚠️ Python for vs lax.scan
짧은 loop (10 회 이하) 는 Python for 로 unroll 해도 OK — compile 시간 거의 같음. 그러나 100, 1000 step 학습 loop 를 Python for 로 jit 하면 — compile 이 분 단위로 걸리고 코드 사이즈가 폭발. 항상 scan 으로.
가이드라인:
- 인자가 dispatch flag 면 →
static_argnames - data-dependent 분기면 →
jax.lax.cond/jnp.where - 고정 길이 loop, 짧음 (≤ 10) → Python
for - 고정 길이 loop, 김 →
jax.lax.fori_loop또는scan - 가변 길이 / 조건 종료 →
jax.lax.while_loop - step 마다 output 수집 →
jax.lax.scan