실제로 만나는 jit 에러 — 5 가지가 95% 차지. 각각 패턴 + 해결:
1. Tracer 에 if
@jax.jit
def f(x):
if x.sum() > 0: # ❌ ConcretizationTypeError
return x
return -x
# Fix
@jax.jit
def f(x):
return jnp.where(x.sum() > 0, x, -x)
2. Python loop 가 길거나 가변
@jax.jit
def f(x, n): # n 이 traced 면 안 됨
for _ in range(n):
x = x + 1
return x
# Fix 1: n 이 정해진 작은 정수
@jax.jit
def f(x):
for _ in range(10):
x = x + 1
return x
# Fix 2: scan
def f(x, n):
def step(carry, _):
return carry + 1, None
return jax.lax.fori_loop(0, n, lambda i, c: c + 1, x)
3. Python list / dict 의 mutation
@jax.jit
def f(x):
out = []
for i in range(3):
out.append(x[i] * 2) # OK if 길이 static
return jnp.stack(out)
# 더 좋은 방식: 한 번에
@jax.jit
def f(x):
return x[:3] * 2
4. shape 이 호출마다 변함
f = jax.jit(lambda x: x ** 2)
f(jnp.arange(10.)) # compile 1
f(jnp.arange(20.)) # compile 2 — 의도한 거?
# Fix: padding
def pad_to_max(x, max_len=1024):
return jnp.pad(x, (0, max_len - x.shape[0]))
f(pad_to_max(short_x)) # 항상 같은 shape
f(pad_to_max(longer_x)) # cache hit
학습 시 — drop_last=True 로 batch size 일정하게 유지하거나, masking 으로 padding 처리.
5. static_argnames 가 hashable 이 아님
@partial(jax.jit, static_argnames=("shape",))
def f(x, shape):
return jnp.zeros(shape) + x.sum()
f(x, [3, 3]) # ❌ list 는 unhashable
f(x, (3, 3)) # ✅
static 인자는 hashable + comparable 이어야 함 (tuple, str, int, frozenset, hashable dataclass).
🐛 디버깅 순서
(1) jit 빼고 eager 로 돌려 봐 — pure Python error 처럼 잡힘. (2) jax.make_jaxpr(f)(x) 로 trace 만 해 봐 — compile 직전까지 도달. (3) jax.config.update("jax_disable_jit", True) 로 전체 비활성화. (4) 작은 input 으로 reduce 하면서 break point 찾기.
실용 도구:
import jax
import os
# 환경 변수로 jit 끄기
os.environ["JAX_DISABLE_JIT"] = "1"
# 또는 context manager
with jax.disable_jit():
f(x) # eager mode
# trace 횟수 보기
@jax.jit
def f(x):
print("trace!") # trace 마다 한 번
return x ** 2
f(jnp.arange(10.)) # trace!
f(jnp.arange(10.)) # (cache hit, no print)
f(jnp.arange(20.)) # trace!
처음 jit 다룰 땐 — trace count 가 의도와 다르면 그게 가장 흔한 버그 근원이야.