JAX 의 가장 중요한 한 가지 규칙이야 — 함수는 순수 (pure) 해야 한다. 이게 뭐냐면:
- 같은 input → 항상 같은 output. 결정적.
- side effect 없음. 외부 상태를 읽지도 쓰지도 않음. 출력은 return 값만.
예를 들어:
# PURE
def add(a, b):
return a + b
def normalize(x):
return (x - x.mean()) / x.std()
# IMPURE
counter = 0
def step():
global counter
counter += 1 # ❌ side effect
return counter
cache = {}
def lookup(k):
if k in cache: # ❌ external state 읽기
return cache[k]
cache[k] = expensive(k) # ❌ external state 쓰기
return cache[k]
JAX 는 왜 이걸 강요하나 — 변환 (jit, grad, vmap, pmap) 들이 다 함수의 정의를 trace 해서 작동해. trace 란 — 함수를 한 번 실행하면서 어떤 연산이 어떤 순서로 일어나는지 기록하는 거. 이 trace 를 compile / differentiate / vectorize 해.
그런데 — 함수 안에 side effect 가 있으면? 그 side effect 는 trace 시점에 한 번만 일어남. compile 된 후 호출에선 절대 다시 안 일어나. 그래서:
import jax
call_count = 0
@jax.jit
def f(x):
global call_count
call_count += 1 # 이게 정확히 한 번만 실행됨 (trace 때)
return x * 2
f(1.0) # call_count == 1 (trace + 실행)
f(2.0) # call_count == 1 (trace 안 함, cached)
f(3.0) # call_count == 1 (cached)
print(call_count) # 1, not 3
이걸 디버깅하기가 정말 어려워. "분명 함수 호출했는데 왜 counter 가 안 늘지?" — pure 가 아니라서 그래.
🌿 함수형 사고로 전환
JAX 가 Python 의 동적 / mutable 본성을 거스르는 이유 — XLA 가 미리 compile 하기 위해선 함수가 deterministic + side-effect-free 여야만 가능해. "함수 = 입력에서 출력으로의 수학적 mapping" 이라는 옛 정의로 돌아가는 거. 처음엔 답답하지만, 익으면 — 이게 더 깨끗한 모델이라는 걸 알게 돼.
state 가 정말 필요할 땐 — 명시적으로 함수 인자로 주고받음:
# PyTorch 식 (JAX 아님)
class Counter:
def __init__(self): self.n = 0
def step(self): self.n += 1; return self.n
# JAX 식
def step(state):
return state + 1, state + 1
state = 0
state, output = step(state) # 항상 explicit
state, output = step(state)
state 를 함수 input/output 으로 분리하는 패턴 — 이게 JAX 어디서나 등장해. Optax, Flax NNX, training loop 다 이 모양이야.