C.W.K.
Stream
Lesson 01 of 06 · published

JAX 가 random 을 다르게 다루는 이유

~8 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

NumPy / PyTorch 의 random 은 — global state 기반.

import numpy as np
np.random.seed(42)
a = np.random.normal(size=10)   # 어떤 값
b = np.random.normal(size=10)   # 또 다른 값 (state 진행됨)
c = np.random.normal(size=10)   # 또 다른 값

편함. 그런데 — global state 는 JAX 의 purity 를 깸. global 읽고/쓰는 함수는 jit 안에서 한 번만 trace 되고, 그 후 같은 값 반복.

JAX 의 해법: state 를 명시적 인자로. key 라는 것을 모든 random 함수가 받음.

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
a = random.normal(key, (10,))   # 같은 key 면 항상 같은 값
b = random.normal(key, (10,))   # ← 같은 값!  (a == b)

잠깐, 두 호출이 같은 값 — 이건 의도된 거. 같은 key = 같은 random. 이게 deterministic / reproducible.

다른 random 이 필요하면 — key 를 split:

key, subkey1, subkey2 = random.split(key, 3)

a = random.normal(subkey1, (10,))   # 어떤 값
b = random.normal(subkey2, (10,))   # 다른 값
# subkey1 ≠ subkey2 → a ≠ b

# key 도 다음 split 위해 따로 보존
key, sk = random.split(key)
c = random.normal(sk, (10,))

처음엔 귀찮아 보임. 그런데 이게 주는 것:

  • 완벽한 재현성: 같은 시작 key → 절대로 같은 값. CI/CD 에서 학습이 비트 단위 일치.
  • Pure: jit 안에서 안전. 매 호출마다 trace 가 정확.
  • 병렬 친화: 다른 device / thread 가 다른 key 로 시작하면 — global state 충돌 없음.
  • 세분화된 컨트롤: 학습의 어떤 부분에 어떤 key 를 줬는지 명확. 디버깅 가능.
# PyTorch 식: random 이 어디서 흐르는지 볼 수 없음
torch.manual_seed(0)
def step():
    noise = torch.randn(...)   # 이 noise 는 어디서 왔지?

# JAX 식: 명시적
def step(key):
    noise = jax.random.normal(key, ...)   # key 가 input

🎲 Pure randomness 의 가치

처음엔 "왜 이렇게 까다롭게?" 싶지만 — 학습이 발산할 때, model 결과가 미묘하게 달라질 때, 그 원인이 random seed 인지 다른 무엇인지 구분 가능한 게 — research 에서 결정적. 한 번 익히면 — global random 이 얼마나 위험한 추상화였는지 보여.

중요한 한 가지 — JAX 의 PRNG 은 ThreeFry / RBG 알고리즘 (counter-based). 같은 seed + 같은 key 처리 → 어느 hardware 에서든 같은 결과. CPU, GPU, TPU 모두 동일.

Code

# NumPy: hidden global state — NOT functional
import numpy as np

np.random.seed(42)
print(np.random.randn())  # 0.4967...
print(np.random.randn())  # -0.1383... (different! state mutated)

# JAX: explicit state — fully functional
import jax
import jax.numpy as jnp

key = jax.random.key(42)
print(jax.random.normal(key))    # always the same value
print(jax.random.normal(key))    # exact same value again!

External links

Exercise

np.random.normal(size=10) 두 번 호출 — 값 변함 (global state). jax.random.normal(key, shape=(10,)) 두 번 같은 key — 동일. split 후 다시 — 다름. functional PRNG 모델 전체가 이 하나에.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.