JAX random 사용 중 자주 나는 실수와 재현 가능한 학습 코드 작성법.
실수 1: split 안 함
# BAD
key = random.PRNGKey(0)
for step in range(100):
noise = random.normal(key, (10,)) # 매 step 같은 noise
...
# GOOD
key = random.PRNGKey(0)
for step in range(100):
key, subkey = random.split(key)
noise = random.normal(subkey, (10,)) # 매 step 다른 noise
...
실수 2: split 반환값 무시
# BAD
random.split(key, 2) # 결과 안 받음 — key 는 안 변함
noise = random.normal(key, ...)
# GOOD
key, subkey = random.split(key)
noise = random.normal(subkey, ...)
(JAX array 는 immutable 이라 split 결과 안 받으면 그냥 사라짐.)
실수 3: jit 안에서 PRNGKey(0)
# BAD — jit 안에서 새 key 만들기
@jax.jit
def f(x):
key = random.PRNGKey(0) # 매 호출 같은 key, trace 시점 한 번만 생성
return x + random.normal(key, x.shape)
# GOOD — key 를 인자로
@jax.jit
def f(x, key):
return x + random.normal(key, x.shape)
jit 안의 PRNGKey 는 한 번만 trace 됨. 그러면 — 매 호출 같은 random.
실수 4: 학습 step 마다 같은 base key
# BAD
@jax.jit
def step(params, x, base_key):
return params + random.normal(base_key, params.shape) # 항상 같은 noise
# GOOD — fold_in 으로 step 별
@jax.jit
def step(params, x, base_key, step_idx):
step_key = random.fold_in(base_key, step_idx)
return params + random.normal(step_key, params.shape)
실수 5: vmap 에서 broadcast key
# BAD — 모든 example 같은 noise
batched = jax.vmap(f, in_axes=(0, None))
out = batched(x_batch, key)
# GOOD
keys = random.split(key, x_batch.shape[0])
batched = jax.vmap(f, in_axes=(0, 0))
out = batched(x_batch, keys)
재현 가능한 학습 코드
def train(seed=42, num_epochs=10):
'''주어진 seed 로 — 같은 학습 결과 보장'''
key = random.PRNGKey(seed)
# 모든 random 의 시작점을 분리
key, k_init, k_data_shuffle, k_train = random.split(key, 4)
# model 초기화
params = init_model(k_init)
# data shuffle
perm = random.permutation(k_data_shuffle, jnp.arange(N))
X_shuffled = X[perm]
# 학습
for epoch in range(num_epochs):
for batch_idx in range(num_batches):
# step 별 key — fold_in 으로
step = epoch * num_batches + batch_idx
step_key = random.fold_in(k_train, step)
params = train_step(params, batch, step_key)
return params
# 같은 seed 로 두 번 — bit-identical 결과
p1 = train(seed=42)
p2 = train(seed=42)
print(jax.tree.map(jnp.allclose, p1, p2)) # 모두 True
🎯 재현성 체크리스트
(1) 시작 key 는 한 곳에서 명시. (2) 모든 random 사용처는 split 또는 fold_in. (3) data 순서는 deterministic permutation. (4) jit 안에서 새 PRNGKey 안 만들기. (5) 같은 seed 로 두 번 돌려서 bit-identical 확인 — CI 에서 자동화 가능. JAX 의 약속 — 위 다 따르면 어느 hardware 에서든 동일.
NumPy / PyTorch 사용자의 흔한 미스 — 각 framework 의 random seed 를 따로따로 설정하다 빠뜨림. JAX 는 단일 PRNG model 이라 — 하나만 챙기면 됨.