C.W.K.
Stream
Lesson 02 of 05 · published

Purity 를 깨는 흔한 패턴들

~10 min · purity, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

실수가 일어나는 정확한 자리 — 이걸 알면 디버깅이 빨라져. 함정 갤러리 (Rogues Gallery):

1. global / nonlocal 읽기

scale = 2.0

@jax.jit
def f(x):
    return x * scale  # ❌ scale 이 첫 호출 때 capture 됨, 이후 변경 무시

scale 을 9.0 으로 바꾸고 다시 호출해도 — 여전히 2.0 곱함. trace 결과가 cached.

2. list / dict 같은 mutable container mutation

history = []

@jax.jit
def f(x):
    history.append(x)  # ❌ side effect. trace 시점에 한 번 append.
    return x * 2

3. print, logging, file write

@jax.jit
def f(x):
    print(f"x is {x}")  # ❌ trace 때만 print 됨. 게다가 x 는 Tracer object!
    return x ** 2

print 가 trace 때 도는 게 디버깅엔 가끔 유용한데, "런타임 마다 print" 라고 착각하면 함정. 정말 런타임에 print 하려면 jax.debug.print('{x}', x=x).

4. random.random / np.random

import random

@jax.jit
def f(x):
    noise = random.gauss(0, 1)  # ❌ trace 때 한 번. 매 호출마다 같은 noise.
    return x + noise

JAX 는 jax.random + key 만 받아 — Track 8 에서 정식으로.

5. exception 으로 control flow

@jax.jit
def f(x):
    try:
        return jnp.log(x)
    except:  # ❌ trace 시점엔 실제 값 없음. exception 안 일어남.
        return jnp.zeros_like(x)

6. time.time(), datetime.now()

@jax.jit
def f(x):
    seed = int(time.time())  # ❌ trace 한 순간의 시간만 capture
    ...

7. iterator / generator state

it = iter(range(100))

@jax.jit
def f(x):
    return x + next(it)  # ❌ iterator state 가 mutate. 그것도 첫 호출 한 번.

⚠️ "조용히 틀린 답" 함정

위의 거의 모든 케이스 — JAX 는 에러 안 내. 그냥 캐시된 결과를 반환해. 학습이 안 되거나 결과가 이상하면 — 가장 먼저 의심해야 할 게 purity. dummy print 를 함수 안에 넣고 같은 인자로 두 번 호출했을 때 안 찍히면 — pure 가 아닌 거야.

방어:

# global 대신 인자로
def f(x, scale):  # ✅
    return x * scale

# state 는 in/out 으로
def step(state, x):
    new_state = state + 1
    return new_state, x * new_state

# random 은 key 로
def f(x, key):
    noise = jax.random.normal(key, x.shape)
    return x + noise

# print 는 jax.debug.print
@jax.jit
def f(x):
    jax.debug.print("x is {x}", x=x)  # ✅ runtime print
    return x ** 2

이 7 가지 함정만 외우고 있어도 — 90% 의 "왜 안 돌지?" 가 사라져.

Code

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# x[0] = 99  # TypeError: JAX arrays are immutable

# Fix: use .at[].set()
x_new = x.at[0].set(99)  # Returns new array, x is unchanged
import jax
import jax.numpy as jnp

learning_rate = 0.01

# BAD: reads from closure
@jax.jit
def update_bad(params, grads):
    return params - learning_rate * grads

# GOOD: pass as argument
@jax.jit
def update_good(params, grads, lr):
    return params - lr * grads

# OR: use static_argnums for values that rarely change
@jax.jit
def update_static(params, grads, lr):
    return params - lr * grads
# JAX will recompile when lr changes, but that's acceptable if it rarely does
import jax
import jax.numpy as jnp

@jax.jit
def fn_with_print(x):
    print("This runs during TRACING only, not execution!")
    y = x + 1
    print(f"y = {y}")  # Prints a tracer object, not a number
    return y

result = fn_with_print(jnp.array(5.0))
# Output during first call:
# "This runs during TRACING only, not execution!"
# "y = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace...>"

# Second call: no print at all — JIT reuses the cached trace
result2 = fn_with_print(jnp.array(10.0))
@jax.jit
def fn_with_debug_print(x):
    y = x + 1
    jax.debug.print("y = {}", y)  # Prints at execution time!
    return y

fn_with_debug_print(jnp.array(5.0))   # Prints "y = 6.0"
fn_with_debug_print(jnp.array(10.0))  # Prints "y = 11.0"
import jax
import jax.numpy as jnp

# NumPy uses global state — IMPURE
import numpy as np
np.random.seed(42)
a = np.random.randn(3)  # Mutates global RNG state
b = np.random.randn(3)  # Different result — depends on hidden state

# JAX uses explicit keys — PURE
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
a = jax.random.normal(key1, (3,))  # Deterministic given key1
b = jax.random.normal(key2, (3,))  # Deterministic given key2

# Same key always gives same result
a_again = jax.random.normal(key1, (3,))
print(jnp.allclose(a, a_again))  # True — pure!
import jax
import jax.numpy as jnp

# PROBLEMATIC under JIT: Python if depends on a traced value
@jax.jit
def bad_relu(x):
    if x > 0:  # ConcretizationTypeError!
        return x
    else:
        return 0.0

# GOOD: use JAX control flow
@jax.jit
def good_relu(x):
    return jnp.where(x > 0, x, 0.0)

External links

Exercise

최근 Python script 훑기. purity 위반하는 곳 모두 — global 읽기, list mutation, time/random side effect, print. 표로 정리. 먼저 refactor 할 top 3 선택.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.