실전에서 key 를 어떻게 흘리느냐 — 학습 코드의 가독성을 좌우. 흔한 패턴 4 가지:
패턴 1: pass-through (key, value 반환)
def sample_noise(key, shape):
return random.normal(key, shape)
def model_with_dropout(params, x, key, p=0.5):
key, subkey = random.split(key)
mask = random.bernoulli(subkey, 1 - p, x.shape)
return x * mask / (1 - p), key # 새 key 도 반환
# 호출
key = random.PRNGKey(0)
y, key = model_with_dropout(params, x, key)
y2, key = model_with_dropout(params, y, key)
장점: 명시적. 단점: 모든 함수가 key 를 in/out 으로 처리해야 — verbose.
패턴 2: 미리 split, list 로 전달
def init_model(key, n_layers):
keys = random.split(key, n_layers)
return [init_layer(k) for k in keys]
장점: 깨끗. 단점: split 횟수 미리 알아야 함.
패턴 3: 함수 안에서만 split
def sample_things(key):
k1, k2, k3 = random.split(key, 3)
a = random.normal(k1, ...)
b = random.uniform(k2, ...)
c = random.bernoulli(k3, ...)
return a, b, c
# 호출 측은 한 key 만 주면 됨
key, sub = random.split(key)
a, b, c = sample_things(sub)
장점: 함수 시그니처 깨끗. 단점: 함수 안의 random 흐름이 caller 한테 안 보임.
패턴 4: fold_in 으로 step 별 key
@jax.jit
def train_step(params, x, y, base_key, step):
'''매 step 마다 다른 random — fold_in 으로 deterministic'''
step_key = random.fold_in(base_key, step)
k_dropout, k_noise = random.split(step_key)
# ... random 사용
장점: 학습 loop 에서 추가 state 불필요. step 번호만 있으면 OK. JIT 친화. 단점: fold_in 의 statistical 독립성이 split 만큼 강하진 않다는 (이론적) 우려가 있어 — 실제론 문제 안 됨.
JAX 표준 — train state 안에 key 보존
@dataclass
class TrainState:
params: Any
opt_state: Any
step: int
key: Any # ← key 도 state 의 일부
@jax.jit
def train_step(state, batch):
key, subkey = random.split(state.key)
# subkey 로 random
...
return TrainState(
params=new_params,
opt_state=new_opt_state,
step=state.step + 1,
key=key, # 다음 step 위해
)
이 패턴이 — Flax / Equinox / 표준 trainer 코드 어디서나. key 가 state 의 한 leaf.
📐 어떤 패턴 쓸지
(1) 단순 함수 — 패턴 3 (함수 안에서 split). (2) 학습 loop — 패턴 4 (fold_in) 또는 state 안에 key 보존. (3) 큰 model 초기화 — 패턴 2 (미리 list). 가장 중요한 한 가지: 같은 코드 안에서 일관된 패턴. mixing 하면 누가 어디서 random 먹었는지 추적 불가.
흔한 함정: key 재사용. 같은 key 두 번 쓰면 같은 random. 일부러 그러는 거 (test) 면 OK, 아니면 split 빠뜨린 버그.