C.W.K.
Stream
Lesson 06 of 06 · published

흔한 실수와 재현성

~13 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX random 사용 중 자주 나는 실수와 재현 가능한 학습 코드 작성법.

실수 1: split 안 함

# BAD
key = random.PRNGKey(0)
for step in range(100):
    noise = random.normal(key, (10,))   # 매 step 같은 noise
    ...

# GOOD
key = random.PRNGKey(0)
for step in range(100):
    key, subkey = random.split(key)
    noise = random.normal(subkey, (10,))   # 매 step 다른 noise
    ...

실수 2: split 반환값 무시

# BAD
random.split(key, 2)   # 결과 안 받음 — key 는 안 변함
noise = random.normal(key, ...)

# GOOD
key, subkey = random.split(key)
noise = random.normal(subkey, ...)

(JAX array 는 immutable 이라 split 결과 안 받으면 그냥 사라짐.)

실수 3: jit 안에서 PRNGKey(0)

# BAD — jit 안에서 새 key 만들기
@jax.jit
def f(x):
    key = random.PRNGKey(0)   # 매 호출 같은 key, trace 시점 한 번만 생성
    return x + random.normal(key, x.shape)

# GOOD — key 를 인자로
@jax.jit
def f(x, key):
    return x + random.normal(key, x.shape)

jit 안의 PRNGKey 는 한 번만 trace 됨. 그러면 — 매 호출 같은 random.

실수 4: 학습 step 마다 같은 base key

# BAD
@jax.jit
def step(params, x, base_key):
    return params + random.normal(base_key, params.shape)   # 항상 같은 noise

# GOOD — fold_in 으로 step 별
@jax.jit
def step(params, x, base_key, step_idx):
    step_key = random.fold_in(base_key, step_idx)
    return params + random.normal(step_key, params.shape)

실수 5: vmap 에서 broadcast key

# BAD — 모든 example 같은 noise
batched = jax.vmap(f, in_axes=(0, None))
out = batched(x_batch, key)

# GOOD
keys = random.split(key, x_batch.shape[0])
batched = jax.vmap(f, in_axes=(0, 0))
out = batched(x_batch, keys)

재현 가능한 학습 코드

def train(seed=42, num_epochs=10):
    '''주어진 seed 로 — 같은 학습 결과 보장'''
    key = random.PRNGKey(seed)

    # 모든 random 의 시작점을 분리
    key, k_init, k_data_shuffle, k_train = random.split(key, 4)

    # model 초기화
    params = init_model(k_init)

    # data shuffle
    perm = random.permutation(k_data_shuffle, jnp.arange(N))
    X_shuffled = X[perm]

    # 학습
    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            # step 별 key — fold_in 으로
            step = epoch * num_batches + batch_idx
            step_key = random.fold_in(k_train, step)
            params = train_step(params, batch, step_key)

    return params

# 같은 seed 로 두 번 — bit-identical 결과
p1 = train(seed=42)
p2 = train(seed=42)
print(jax.tree.map(jnp.allclose, p1, p2))   # 모두 True

🎯 재현성 체크리스트

(1) 시작 key 는 한 곳에서 명시. (2) 모든 random 사용처는 split 또는 fold_in. (3) data 순서는 deterministic permutation. (4) jit 안에서 새 PRNGKey 안 만들기. (5) 같은 seed 로 두 번 돌려서 bit-identical 확인 — CI 에서 자동화 가능. JAX 의 약속 — 위 다 따르면 어느 hardware 에서든 동일.

NumPy / PyTorch 사용자의 흔한 미스 — 각 framework 의 random seed 를 따로따로 설정하다 빠뜨림. JAX 는 단일 PRNG model 이라 — 하나만 챙기면 됨.

Code

import jax
import jax.numpy as jnp

key = jax.random.key(0)

# WRONG: same key → same values → correlated randomness
w1 = jax.random.normal(key, (100, 50))
w2 = jax.random.normal(key, (50, 10))
# w1 and w2 share the same random pattern (scaled to shape)!

# RIGHT: split first
k1, k2 = jax.random.split(key)
w1 = jax.random.normal(k1, (100, 50))
w2 = jax.random.normal(k2, (50, 10))
# Fully independent
key = jax.random.key(0)

# WRONG: same key every iteration
samples_bad = []
for i in range(5):
    samples_bad.append(jax.random.normal(key, (3,)))
# All five samples are IDENTICAL

# RIGHT: split at each iteration
samples_good = []
for i in range(5):
    key, subkey = jax.random.split(key)
    samples_good.append(jax.random.normal(subkey, (3,)))
# All five samples are different
import random
import numpy as np

# WRONG: these are not JIT-compatible
# x = random.random()           # Python random — not traced
# x = np.random.randn(3)        # NumPy random — not traced

# RIGHT: always use jax.random inside jitted functions
@jax.jit
def generate(key):
    return jax.random.normal(key, (3,))
key = jax.random.key(123)

# Run 1
a = jax.random.normal(key, (5,))

# Run 2 — same key, same result
b = jax.random.normal(key, (5,))

assert jnp.array_equal(a, b)  # True — always

# This works even under jit
@jax.jit
def sample(key):
    return jax.random.normal(key, (1000,))

c = sample(key)
d = sample(key)
assert jnp.array_equal(c, d)  # True
import jax
import jax.numpy as jnp

def create_train_state(key, input_dim, hidden_dim, output_dim):
    """Initialize model with reproducible randomness."""
    k1, k2, k3 = jax.random.split(key, 3)
    params = {
        'w1': jax.random.normal(k1, (input_dim, hidden_dim)) * 0.01,
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(k2, (hidden_dim, output_dim)) * 0.01,
        'b2': jnp.zeros(output_dim),
    }
    return params

@jax.jit
def train_step(params, batch, dropout_key):
    """Single training step with dropout."""
    x, y = batch
    k1, k2 = jax.random.split(dropout_key)

    def loss_fn(params):
        h = jax.nn.relu(x @ params['w1'] + params['b1'])
        mask = jax.random.bernoulli(k1, 0.7, h.shape)
        h = jnp.where(mask, h / 0.7, 0.0)
        logits = h @ params['w2'] + params['b2']
        return jnp.mean((logits - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

# Training loop with proper key management
key = jax.random.key(0)
key, init_key = jax.random.split(key)
params = create_train_state(init_key, 784, 256, 10)

for step in range(1000):
    key, step_key = jax.random.split(key)
    # params, loss = train_step(params, batch, step_key)

External links

Exercise

30 줄 학습 script (model + loss + 50 step). 같은 PRNGKey(0) 로 두 번 실행. bit-identical 확인. 한 줄 바꿔 jax.random.PRNGKey(int(time.time())) 로 — drift 관찰. 재현성은 — 만들어지는 거지 바라는 게 아냐.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.