C.W.K.
Stream
Lesson 04 of 05 · published

Impure Code 디버깅: 에러 메시지 해독

~10 min · purity, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 에러 메시지가 처음엔 외계어처럼 보이는데 — 패턴 익히면 빠르게 해석 가능.

1. ConcretizationTypeError

jax.errors.ConcretizationTypeError: Abstract tracer value encountered
where concrete value is expected: Traced<ShapedArray(float32[])>...

의미: traced 값을 Python 의 concrete 값처럼 다루려 함 (예: if x > 0:, .item(), int(x)).

@jax.jit
def f(x):
    if x > 0:        # ❌ Python if on Tracer
        return x
    else:
        return -x

해결:

@jax.jit
def f(x):
    return jnp.where(x > 0, x, -x)  # ✅
# 또는
@jax.jit
def f(x):
    return jax.lax.cond(x > 0, lambda: x, lambda: -x)

2. TracerArrayConversionError

TracerArrayConversionError: The numpy.ndarray conversion method
was called on Traced<...>

의미: traced 값을 numpy 함수에 넣음. numpy 가 concrete array 를 원하는데 Tracer 받았다.

@jax.jit
def f(x):
    return np.sin(x)  # ❌ np 대신 jnp

해결: jnp 로.

3. Cached compile, side effect 안 도는 증상

에러 안 남. 결과가 이상함.

global_counter = 0
@jax.jit
def f(x):
    global global_counter
    global_counter += 1
    return x * 2

f(1.0); f(2.0); f(3.0)
print(global_counter)  # 1 (3 아님!)

해결: state 를 함수 인자로 노출. def f(x, counter): return x*2, counter+1.

4. NonHashableStaticArgumentsError

TypeError: unhashable type: 'list'

의미: static_argnames 로 표시된 인자가 hashable 이어야 함 (list 는 X, tuple 은 O).

@partial(jax.jit, static_argnames=("shape",))
def f(x, shape):
    return jnp.zeros(shape) + x

f(1.0, [3, 3])      # ❌ list
f(1.0, (3, 3))      # ✅ tuple

5. ShapeMismatchError

jax 에서 두 호출의 shape 이 달라서 recompile 됨

의미: 같은 jit 함수를 다른 shape 으로 부름 → 매번 새로 compile. 의도한 거면 OK, 아니면 static_argnames 로 분리하거나 padding.

🔍 디버깅 routine

(1) 에러 메시지의 첫 줄만 봐도 카테고리 잡힘. ConcretizationTypeError = control flow on traced value. TracerArrayConversion = wrong namespace. (2) jit 빼고 eager 모드로 돌려 봐 — 그러면 일반 Python error 처럼 동작해서 원인 찾기 쉬움. (3) jax.disable_jit() context manager 로 감싸도 됨.

실용 팁: 새 함수를 처음 짤 땐 jit 안 붙이고 eager 모드로 한 번 돌려서 정상 동작 확인 후 jit 추가. compile 에러는 까다로워서 — 한 번에 다 잡으려고 하면 시간 낭비.

Code

import jax
import jax.numpy as jnp

@jax.jit
def bad_fn(x):
    if x > 0:  # Needs concrete bool, but x is a tracer
        return x
    else:
        return -x

# ConcretizationTypeError: Abstract tracer value encountered
# where a concrete value is expected: Traced<ShapedArray(float32[])>
# This often arises when using Python control flow with traced values.

# Fix 1: Use jnp.where (no branching)
@jax.jit
def fix1(x):
    return jnp.where(x > 0, x, -x)

# Fix 2: Use jax.lax.cond (traced branching)
@jax.jit
def fix2(x):
    return jax.lax.cond(x > 0, lambda x: x, lambda x: -x, x)
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def bad_conversion(x):
    # Trying to use a traced value as a Python int
    n = int(x.shape[0])  # This works (shape is static)
    val = int(x[0])       # ERROR: can't convert tracer to int
    return x[:val]

# Fix: keep everything in JAX-land
@jax.jit
def fixed(x):
    return jax.lax.dynamic_slice(x, (0,), (2,))
import jax
import jax.numpy as jnp

stored_values = []

@jax.jit
def leaky_fn(x):
    y = x + 1
    stored_values.append(y)  # Tracer escapes into global list!
    return y

# This may raise a leaked tracer error or produce silent bugs
# Fix: don't store traced values outside the function
import jax
import jax.numpy as jnp

def my_function(x):
    """Some complex computation."""
    intermediate = jnp.sin(x) * 2
    # ... more code ...
    return jnp.sum(intermediate)

# Step 1: Test WITHOUT jit first
x = jnp.array([1.0, 2.0, 3.0])
result = my_function(x)  # If this works, the logic is correct
print(result)

# Step 2: Use print to inspect values
print(f"intermediate shape: {jnp.sin(x).shape}")

# Step 3: Add jit and check for errors
jitted = jax.jit(my_function)
result = jitted(x)

# Step 4: If jit fails, use jax.make_jaxpr to see the traced computation
print(jax.make_jaxpr(my_function)(x))

External links

Exercise

3 개의 흔한 impurity error 의도적으로 재현: (1) global 에서 ConcretizationTypeError, (2) list-of-tracer 에서 TracerArrayConversionError, (3) print 의 'side channel' 로 인한 silent stale cache. 각각 fix + 한 줄 교훈.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.