C.W.K.
Stream
Lesson 03 of 06 · published

Key Management 패턴

~9 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

실전에서 key 를 어떻게 흘리느냐 — 학습 코드의 가독성을 좌우. 흔한 패턴 4 가지:

패턴 1: pass-through (key, value 반환)

def sample_noise(key, shape):
    return random.normal(key, shape)

def model_with_dropout(params, x, key, p=0.5):
    key, subkey = random.split(key)
    mask = random.bernoulli(subkey, 1 - p, x.shape)
    return x * mask / (1 - p), key   # 새 key 도 반환

# 호출
key = random.PRNGKey(0)
y, key = model_with_dropout(params, x, key)
y2, key = model_with_dropout(params, y, key)

장점: 명시적. 단점: 모든 함수가 key 를 in/out 으로 처리해야 — verbose.

패턴 2: 미리 split, list 로 전달

def init_model(key, n_layers):
    keys = random.split(key, n_layers)
    return [init_layer(k) for k in keys]

장점: 깨끗. 단점: split 횟수 미리 알아야 함.

패턴 3: 함수 안에서만 split

def sample_things(key):
    k1, k2, k3 = random.split(key, 3)
    a = random.normal(k1, ...)
    b = random.uniform(k2, ...)
    c = random.bernoulli(k3, ...)
    return a, b, c

# 호출 측은 한 key 만 주면 됨
key, sub = random.split(key)
a, b, c = sample_things(sub)

장점: 함수 시그니처 깨끗. 단점: 함수 안의 random 흐름이 caller 한테 안 보임.

패턴 4: fold_in 으로 step 별 key

@jax.jit
def train_step(params, x, y, base_key, step):
    '''매 step 마다 다른 random — fold_in 으로 deterministic'''
    step_key = random.fold_in(base_key, step)
    k_dropout, k_noise = random.split(step_key)

    # ... random 사용

장점: 학습 loop 에서 추가 state 불필요. step 번호만 있으면 OK. JIT 친화. 단점: fold_in 의 statistical 독립성이 split 만큼 강하진 않다는 (이론적) 우려가 있어 — 실제론 문제 안 됨.

JAX 표준 — train state 안에 key 보존

@dataclass
class TrainState:
    params: Any
    opt_state: Any
    step: int
    key: Any   # ← key 도 state 의 일부

@jax.jit
def train_step(state, batch):
    key, subkey = random.split(state.key)
    # subkey 로 random
    ...
    return TrainState(
        params=new_params,
        opt_state=new_opt_state,
        step=state.step + 1,
        key=key,   # 다음 step 위해
    )

이 패턴이 — Flax / Equinox / 표준 trainer 코드 어디서나. key 가 state 의 한 leaf.

📐 어떤 패턴 쓸지

(1) 단순 함수 — 패턴 3 (함수 안에서 split). (2) 학습 loop — 패턴 4 (fold_in) 또는 state 안에 key 보존. (3) 큰 model 초기화 — 패턴 2 (미리 list). 가장 중요한 한 가지: 같은 코드 안에서 일관된 패턴. mixing 하면 누가 어디서 random 먹었는지 추적 불가.

흔한 함정: key 재사용. 같은 key 두 번 쓰면 같은 random. 일부러 그러는 거 (test) 면 OK, 아니면 split 빠뜨린 버그.

Code

import jax
import jax.numpy as jnp

def init_layer(key, in_dim, out_dim):
    """Initialize a single layer with split keys."""
    k1, k2 = jax.random.split(key)
    weights = jax.random.normal(k1, (in_dim, out_dim)) * 0.01
    biases = jax.random.normal(k2, (out_dim,)) * 0.01
    return {'w': weights, 'b': biases}

def init_network(key, layer_sizes):
    """Initialize a full network, splitting keys for each layer."""
    params = []
    for i in range(len(layer_sizes) - 1):
        key, subkey = jax.random.split(key)
        params.append(init_layer(subkey, layer_sizes[i], layer_sizes[i+1]))
    return params

key = jax.random.key(42)
params = init_network(key, [784, 256, 128, 10])
print(f"Layer 0 weights shape: {params[0]['w'].shape}")  # (784, 256)
# WRONG: global key — leads to reuse bugs
# global_key = jax.random.key(0)  # Don't do this!

# RIGHT: pass key through the call chain
def dropout(x, key, rate=0.5):
    """Apply dropout with an explicit key."""
    mask = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(mask, x / (1.0 - rate), 0.0)

def forward(params, x, key):
    """Forward pass with dropout — key passed explicitly."""
    k1, k2 = jax.random.split(key)
    h = jnp.tanh(x @ params[0]['w'] + params[0]['b'])
    h = dropout(h, k1, rate=0.3)
    h = jnp.tanh(h @ params[1]['w'] + params[1]['b'])
    h = dropout(h, k2, rate=0.3)
    return h @ params[2]['w'] + params[2]['b']
# Option A: fold_in the step number (deterministic, no key threading)
base_key = jax.random.key(0)
for step in range(1000):
    step_key = jax.random.fold_in(base_key, step)
    # step_key is unique for each step, reproducible from base_key + step

# Option B: split at each step (standard threading)
key = jax.random.key(0)
for step in range(1000):
    key, subkey = jax.random.split(key)
    # use subkey for this step

External links

Exercise

3 개 random sample 필요한 함수 refactor. 3 가지 패턴 시도: (1) 함수 안에서 split, (2) key list 받기, (3) key 1 개 받아 안에서 split. ergonomics + 정확성 trade-off 적기. cwkPippa-식 production 에 넣을 거 선택.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.