NumPy / PyTorch 의 random 은 — global state 기반.
import numpy as np
np.random.seed(42)
a = np.random.normal(size=10) # 어떤 값
b = np.random.normal(size=10) # 또 다른 값 (state 진행됨)
c = np.random.normal(size=10) # 또 다른 값
편함. 그런데 — global state 는 JAX 의 purity 를 깸. global 읽고/쓰는 함수는 jit 안에서 한 번만 trace 되고, 그 후 같은 값 반복.
JAX 의 해법: state 를 명시적 인자로. key 라는 것을 모든 random 함수가 받음.
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(42)
a = random.normal(key, (10,)) # 같은 key 면 항상 같은 값
b = random.normal(key, (10,)) # ← 같은 값! (a == b)
잠깐, 두 호출이 같은 값 — 이건 의도된 거. 같은 key = 같은 random. 이게 deterministic / reproducible.
다른 random 이 필요하면 — key 를 split:
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (10,)) # 어떤 값
b = random.normal(subkey2, (10,)) # 다른 값
# subkey1 ≠ subkey2 → a ≠ b
# key 도 다음 split 위해 따로 보존
key, sk = random.split(key)
c = random.normal(sk, (10,))
처음엔 귀찮아 보임. 그런데 이게 주는 것:
- 완벽한 재현성: 같은 시작 key → 절대로 같은 값. CI/CD 에서 학습이 비트 단위 일치.
- Pure: jit 안에서 안전. 매 호출마다 trace 가 정확.
- 병렬 친화: 다른 device / thread 가 다른 key 로 시작하면 — global state 충돌 없음.
- 세분화된 컨트롤: 학습의 어떤 부분에 어떤 key 를 줬는지 명확. 디버깅 가능.
# PyTorch 식: random 이 어디서 흐르는지 볼 수 없음
torch.manual_seed(0)
def step():
noise = torch.randn(...) # 이 noise 는 어디서 왔지?
# JAX 식: 명시적
def step(key):
noise = jax.random.normal(key, ...) # key 가 input
🎲 Pure randomness 의 가치
처음엔 "왜 이렇게 까다롭게?" 싶지만 — 학습이 발산할 때, model 결과가 미묘하게 달라질 때, 그 원인이 random seed 인지 다른 무엇인지 구분 가능한 게 — research 에서 결정적. 한 번 익히면 — global random 이 얼마나 위험한 추상화였는지 보여.
중요한 한 가지 — JAX 의 PRNG 은 ThreeFry / RBG 알고리즘 (counter-based). 같은 seed + 같은 key 처리 → 어느 hardware 에서든 같은 결과. CPU, GPU, TPU 모두 동일.