# Module-level iterator: hidden mutable state that lives outside the function.
it = iter(range(100))


@jax.jit
def f(x):
    """Anti-example: add the next value of the global iterator ``it`` to ``x``.

    ``next(it)`` is plain Python, so it runs while JAX traces the function,
    not on every call. The value it yields is baked into the compiled
    computation and reused on later calls that hit the same cached trace.
    """
    # ❌ Mutates the iterator state — and only when the function is traced
    # (e.g. on the first call).
    return x + next(it)
⚠️ "조용히 틀린 답" 함정
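위의 거의 모든 케이스에서 JAX 는 에러를 내지 않아. 첫 호출 때 trace 해서 컴파일해 둔 계산을 그대로 재사용할 뿐이야. 학습이 안 되거나 결과가 이상하면 — 가장 먼저 의심해야 할 게 purity. dummy print 를 함수 안에 넣고 같은 인자로 두 번 호출해 봐. 두 번째 호출에서 안 찍히면 Python 본문이 다시 실행되지 않고 있다는 뜻이야 — 그러니 본문이 global, iterator, print 같은 side effect 에 기대고 있다면 그 함수는 pure 가 아닌 거고, 첫 호출 때의 값이 그대로 굳어 있는 거야.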
방어:
# Pass the value in as an argument instead of reading a global.
def f(x, scale):  # ✅
    """Return ``x`` multiplied by ``scale``; both are explicit inputs."""
    return x * scale
# Thread state through the function: take it in, hand the new one back.
def step(state, x):
    """Advance ``state`` by one and scale ``x`` by the new state.

    Returns:
        A ``(new_state, x * new_state)`` tuple; the ``state`` argument itself
        is not modified.
    """
    new_state = state + 1
    return new_state, x * new_state
# Randomness comes from an explicit PRNG key.
def f(x, key):
    """Return ``x`` plus standard-normal noise of shape ``x.shape`` drawn from ``key``."""
    noise = jax.random.normal(key, x.shape)
    return x + noise
# Use jax.debug.print for printing inside jitted code.
@jax.jit
def f(x):
    """Return ``x ** 2``, printing the runtime value of ``x`` via ``jax.debug.print``."""
    jax.debug.print("x is {x}", x=x)  # ✅ runtime print
    return x ** 2
이 7 가지 함정만 외우고 있어도 — 90% 의 "왜 안 돌지?" 가 사라져.
Code
import jax.numpy as jnp
x = jnp.asarray([1, 2, 3])
# In-place assignment is rejected because JAX arrays are immutable:
#     x[0] = 99   ->  TypeError
# Functional update instead: .at[idx].set(value) builds a new array and
# leaves x untouched.
x_new = x.at[0].set(99)
import jax
import jax.numpy as jnp
learning_rate = 0.01


# BAD: reads from closure
@jax.jit
def update_bad(params, grads):
    """Anti-example: gradient step that reads ``learning_rate`` from module scope.

    The global is read while the function is traced, so the value seen at
    trace time is what the compiled computation uses; rebinding
    ``learning_rate`` afterwards does not affect an already-cached trace.
    """
    return params - learning_rate * grads
# GOOD: pass as argument
@jax.jit
def update_good(params, grads, lr):
    """Return ``params - lr * grads`` with the learning rate as an explicit input."""
    return params - lr * grads
# OR: use static_argnums for values that rarely change
from functools import partial


@partial(jax.jit, static_argnums=(2,))
def update_static(params, grads, lr):
    """Return ``params - lr * grads`` with ``lr`` treated as a static argument.

    ``lr`` (positional index 2) is marked static, so it must be hashable and
    is seen as an ordinary Python value while tracing rather than as a tracer.
    """
    return params - lr * grads
# JAX will recompile when lr changes, but that's acceptable if it rarely does
import jax
import jax.numpy as jnp
@jax.jit
def fn_with_print(x):
    """Return ``x + 1``; the plain ``print`` calls only run while tracing."""
    print("This runs during TRACING only, not execution!")
    y = x + 1
    print(f"y = {y}")  # Prints a tracer object, not a number
    return y


result = fn_with_print(jnp.array(5.0))
# Output during first call:
# "This runs during TRACING only, not execution!"
# "y = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace...>"
#   (the exact tracer repr differs between JAX versions)
# Second call: no print at all — JIT reuses the cached trace
result2 = fn_with_print(jnp.array(10.0))
@jax.jit
def fn_with_debug_print(x):
    """Return ``x + 1``, printing the computed value each time the function executes."""
    y = x + 1
    jax.debug.print("y = {}", y)  # Prints at execution time!
    return y


fn_with_debug_print(jnp.array(5.0))  # Prints "y = 6.0"
fn_with_debug_print(jnp.array(10.0))  # Prints "y = 11.0"
import jax
import jax.numpy as jnp
# NumPy uses global state — IMPURE
import numpy as np
# Seed the process-wide legacy generator, then draw from it twice.
np.random.seed(42)
a = np.random.standard_normal(size=3)  # advances the hidden global RNG state
b = np.random.standard_normal(size=3)  # differs from `a`: the state has moved on
# JAX: randomness is a pure function of an explicit key.
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key, num=2)
a = jax.random.normal(key1, shape=(3,))  # fully determined by key1
b = jax.random.normal(key2, shape=(3,))  # fully determined by key2
# Drawing again with key1 reproduces `a` exactly.
a_again = jax.random.normal(key1, shape=(3,))
print(jnp.allclose(a, a_again))  # True — pure!
import jax
import jax.numpy as jnp
# PROBLEMATIC under JIT: Python if depends on a traced value
@jax.jit
def bad_relu(x):
    """Anti-example: ReLU written with a Python ``if`` on a traced value.

    Under ``jax.jit`` ``x`` is a tracer, so ``x > 0`` has no concrete truth
    value and the ``if`` fails while the function is being traced.
    """
    # ConcretizationTypeError! (recent JAX versions report its subclass
    # TracerBoolConversionError)
    if x > 0:
        return x
    else:
        return 0.0
# GOOD: use JAX control flow
@jax.jit
def good_relu(x):
    """Return ``x`` where ``x > 0`` and ``0.0`` elsewhere, selected with ``jnp.where``."""
    return jnp.where(x > 0, x, 0.0)