C.W.K.
Stream
Lesson 04 of 05 · published

흔한 jit 에러와 해결

~8 min · jit, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

실제로 만나는 jit 에러 — 5 가지가 95% 차지. 각각 패턴 + 해결:

1. Tracer 에 if

@jax.jit
def f(x):
    if x.sum() > 0:    # ❌ ConcretizationTypeError
        return x
    return -x

# Fix
@jax.jit
def f(x):
    return jnp.where(x.sum() > 0, x, -x)

2. Python loop 가 길거나 가변

@jax.jit
def f(x, n):       # n 이 traced 면 안 됨
    for _ in range(n):
        x = x + 1
    return x

# Fix 1: n 이 정해진 작은 정수
@jax.jit
def f(x):
    for _ in range(10):
        x = x + 1
    return x

# Fix 2: scan
def f(x, n):
    def step(carry, _):
        return carry + 1, None
    return jax.lax.fori_loop(0, n, lambda i, c: c + 1, x)

3. Python list / dict 의 mutation

@jax.jit
def f(x):
    out = []
    for i in range(3):
        out.append(x[i] * 2)   # OK if 길이 static
    return jnp.stack(out)

# 더 좋은 방식: 한 번에
@jax.jit
def f(x):
    return x[:3] * 2

4. shape 이 호출마다 변함

f = jax.jit(lambda x: x ** 2)

f(jnp.arange(10.))   # compile 1
f(jnp.arange(20.))   # compile 2 — 의도한 거?

# Fix: padding
def pad_to_max(x, max_len=1024):
    return jnp.pad(x, (0, max_len - x.shape[0]))

f(pad_to_max(short_x))  # 항상 같은 shape
f(pad_to_max(longer_x)) # cache hit

학습 시 — drop_last=True 로 batch size 일정하게 유지하거나, masking 으로 padding 처리.

5. static_argnames 가 hashable 이 아님

@partial(jax.jit, static_argnames=("shape",))
def f(x, shape):
    return jnp.zeros(shape) + x.sum()

f(x, [3, 3])   # ❌ list 는 unhashable
f(x, (3, 3))   # ✅

static 인자는 hashable + comparable 이어야 함 (tuple, str, int, frozenset, hashable dataclass).

🐛 디버깅 순서

(1) jit 빼고 eager 로 돌려 봐 — pure Python error 처럼 잡힘. (2) jax.make_jaxpr(f)(x) 로 trace 만 해 봐 — compile 직전까지 도달. (3) jax.config.update("jax_disable_jit", True) 로 전체 비활성화. (4) 작은 input 으로 reduce 하면서 break point 찾기.

실용 도구:

import jax
import os

# 환경 변수로 jit 끄기
os.environ["JAX_DISABLE_JIT"] = "1"

# 또는 context manager
with jax.disable_jit():
    f(x)   # eager mode

# trace 횟수 보기
@jax.jit
def f(x):
    print("trace!")  # trace 마다 한 번
    return x ** 2

f(jnp.arange(10.))   # trace!
f(jnp.arange(10.))   # (cache hit, no print)
f(jnp.arange(20.))   # trace!

처음 jit 다룰 땐 — trace count 가 의도와 다르면 그게 가장 흔한 버그 근원이야.

Code

import jax
import jax.numpy as jnp

# Problem: Python if with traced value
@jax.jit
def categorize(x):
    if x > 0:       # ERROR!
        return 1
    elif x < 0:
        return -1
    else:
        return 0

# Fix: use jnp.sign (or jnp.where for custom logic)
@jax.jit
def categorize_fixed(x):
    return jnp.sign(x).astype(jnp.int32)
import jax
import jax.numpy as jnp

# Problem: dynamic output shape
@jax.jit
def filter_positive(x):
    return x[x > 0]  # Output size depends on values — not allowed!

# Fix 1: Use a fixed-size output with padding
@jax.jit
def filter_positive_fixed(x, max_size):
    mask = x > 0
    indices = jnp.where(mask, size=max_size, fill_value=0)
    return x[indices[0]]

# Fix 2: Don't JIT this particular function
# (sometimes the simplest answer)
def filter_positive_eager(x):
    return x[x > 0]  # Works fine without jit
import jax
import jax.numpy as jnp

# Problem: Python range with traced value
@jax.jit
def repeat_add(x, n):
    result = x
    for i in range(n):  # ERROR if n is traced
        result = result + x
    return result

# Fix 1: make n static
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def repeat_add_static(x, n):
    result = x
    for i in range(n):  # n is concrete now — loop is unrolled
        result = result + x
    return result

# Fix 2: use jax.lax.fori_loop
@jax.jit
def repeat_add_functional(x, n):
    return jax.lax.fori_loop(0, n, lambda i, r: r + x, x)

External links

Exercise

3 개의 minimal repro 손으로 만들기: (1) shape leak (tracer 에 .shape), (2) tracer 의 Python control flow (if x > 0:), (3) traced value 에 .item(). 각각 fix. 자기 만의 jit-debug cheatsheet 로 저장.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.