C.W.K.
Stream
Lesson 01 of 05 · published

Pure 가 뭔지, JAX 가 왜 강요하는지

~9 min · purity, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 가장 중요한 한 가지 규칙이야 — 함수는 순수 (pure) 해야 한다. 이게 뭐냐면:

  1. 같은 input → 항상 같은 output. 결정적.
  2. side effect 없음. 외부 상태를 읽지도 쓰지도 않음. 출력은 return 값만.

예를 들어:

# PURE
def add(a, b):
    return a + b

def normalize(x):
    return (x - x.mean()) / x.std()

# IMPURE
counter = 0
def step():
    global counter
    counter += 1   # ❌ side effect
    return counter

cache = {}
def lookup(k):
    if k in cache:  # ❌ external state 읽기
        return cache[k]
    cache[k] = expensive(k)  # ❌ external state 쓰기
    return cache[k]

JAX 는 왜 이걸 강요하나 — 변환 (jit, grad, vmap, pmap) 들이 다 함수의 정의를 trace 해서 작동해. trace 란 — 함수를 한 번 실행하면서 어떤 연산이 어떤 순서로 일어나는지 기록하는 거. 이 trace 를 compile / differentiate / vectorize 해.

그런데 — 함수 안에 side effect 가 있으면? 그 side effect 는 trace 시점에 한 번만 일어남. compile 된 후 호출에선 절대 다시 안 일어나. 그래서:

import jax

call_count = 0

@jax.jit
def f(x):
    global call_count
    call_count += 1   # 이게 정확히 한 번만 실행됨 (trace 때)
    return x * 2

f(1.0)  # call_count == 1 (trace + 실행)
f(2.0)  # call_count == 1 (trace 안 함, cached)
f(3.0)  # call_count == 1 (cached)

print(call_count)  # 1, not 3

이걸 디버깅하기가 정말 어려워. "분명 함수 호출했는데 왜 counter 가 안 늘지?" — pure 가 아니라서 그래.

🌿 함수형 사고로 전환

JAX 가 Python 의 동적 / mutable 본성을 거스르는 이유 — XLA 가 미리 compile 하기 위해선 함수가 deterministic + side-effect-free 여야만 가능해. "함수 = 입력에서 출력으로의 수학적 mapping" 이라는 옛 정의로 돌아가는 거. 처음엔 답답하지만, 익으면 — 이게 더 깨끗한 모델이라는 걸 알게 돼.

state 가 정말 필요할 땐 — 명시적으로 함수 인자로 주고받음:

# PyTorch 식 (JAX 아님)
class Counter:
    def __init__(self): self.n = 0
    def step(self): self.n += 1; return self.n

# JAX 식
def step(state):
    return state + 1, state + 1

state = 0
state, output = step(state)  # 항상 explicit
state, output = step(state)

state 를 함수 input/output 으로 분리하는 패턴 — 이게 JAX 어디서나 등장해. Optax, Flax NNX, training loop 다 이 모양이야.

Code

import jax.numpy as jnp

# PURE: output depends only on input, no side effects
def pure_fn(x):
    return jnp.sum(x ** 2)

# IMPURE: depends on global state
scale = 2.0
def impure_global(x):
    return jnp.sum(x ** 2) * scale  # Reads global variable!

# IMPURE: has side effects
results = []
def impure_sideeffect(x):
    result = jnp.sum(x ** 2)
    results.append(result)  # Side effect: modifies external list!
    return result

# IMPURE: mutates input
def impure_mutation(x):
    x[0] = 0  # Side effect: modifies input! (also fails in JAX)
    return jnp.sum(x)
import jax
import jax.numpy as jnp

# Demonstration: global state is captured at trace time
multiplier = 2.0

@jax.jit
def buggy_multiply(x):
    return x * multiplier

print(buggy_multiply(jnp.array(3.0)))  # 6.0

multiplier = 10.0  # Change the global
print(buggy_multiply(jnp.array(3.0)))  # Still 6.0! JIT cached the old value
@jax.jit
def correct_multiply(x, multiplier):
    return x * multiplier

print(correct_multiply(jnp.array(3.0), 2.0))   # 6.0
print(correct_multiply(jnp.array(3.0), 10.0))  # 30.0 — correct!

External links

Exercise

3 개의 impure 함수 작성: (1) global 읽기, (2) 인자 list mutate, (3) print 호출. 각각 jit 후 다른 input 으로 두 번 호출. 무엇이 일어나는지 정리 — JAX 가 항상 에러 안 냄, stale 결과 cache 하기도. 그 미묘함이 — purity 가 중요한 이유.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.