JAX 의 에러 메시지가 처음엔 외계어처럼 보이는데 — 패턴 익히면 빠르게 해석 가능.
1. ConcretizationTypeError
jax.errors.ConcretizationTypeError: Abstract tracer value encountered
where concrete value is expected: Traced<ShapedArray(float32[])>...
의미: traced 값을 Python 의 concrete 값처럼 다루려 함 (예: if x > 0:, .item(), int(x)).
@jax.jit
def f(x):
if x > 0: # ❌ Python if on Tracer
return x
else:
return -x
해결:
@jax.jit
def f(x):
return jnp.where(x > 0, x, -x) # ✅
# 또는
@jax.jit
def f(x):
return jax.lax.cond(x > 0, lambda: x, lambda: -x)
2. TracerArrayConversionError
TracerArrayConversionError: The numpy.ndarray conversion method
was called on Traced<...>
의미: traced 값을 numpy 함수에 넣음. numpy 가 concrete array 를 원하는데 Tracer 받았다.
@jax.jit
def f(x):
return np.sin(x) # ❌ np 대신 jnp
해결: jnp 로.
3. Cached compile, side effect 안 도는 증상
에러 안 남. 결과가 이상함.
global_counter = 0
@jax.jit
def f(x):
global global_counter
global_counter += 1
return x * 2
f(1.0); f(2.0); f(3.0)
print(global_counter) # 1 (3 아님!)
해결: state 를 함수 인자로 노출. def f(x, counter): return x*2, counter+1.
4. NonHashableStaticArgumentsError
TypeError: unhashable type: 'list'
의미: static_argnames 로 표시된 인자가 hashable 이어야 함 (list 는 X, tuple 은 O).
@partial(jax.jit, static_argnames=("shape",))
def f(x, shape):
return jnp.zeros(shape) + x
f(1.0, [3, 3]) # ❌ list
f(1.0, (3, 3)) # ✅ tuple
5. ShapeMismatchError
jax 에서 두 호출의 shape 이 달라서 recompile 됨
의미: 같은 jit 함수를 다른 shape 으로 부름 → 매번 새로 compile. 의도한 거면 OK, 아니면 static_argnames 로 분리하거나 padding.
🔍 디버깅 routine
(1) 에러 메시지의 첫 줄만 봐도 카테고리 잡힘. ConcretizationTypeError = control flow on traced value. TracerArrayConversion = wrong namespace. (2) jit 빼고 eager 모드로 돌려 봐 — 그러면 일반 Python error 처럼 동작해서 원인 찾기 쉬움. (3) jax.disable_jit() context manager 로 감싸도 됨.
실용 팁: 새 함수를 처음 짤 땐 jit 안 붙이고 eager 모드로 한 번 돌려서 정상 동작 확인 후 jit 추가. compile 에러는 까다로워서 — 한 번에 다 잡으려고 하면 시간 낭비.