C.W.K.
Stream
Lesson 02 of 06 · published

PRNG Key 와 Splitting

~8 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX random 의 모든 것 — PRNGKey + split. 두 함수면 끝.

import jax
from jax import random

# 시작 key 만들기
key = random.PRNGKey(0)
print(key)   # uint32 array of shape (2,) — 64-bit seed

# split — 한 key 에서 N 개의 새 key
key, subkey = random.split(key)         # 2 개 (default)
keys = random.split(key, 4)              # 4 개
key1, key2, key3 = random.split(key, 3)  # 3 개 unpack

구조 — key 는 그냥 (2,) shape uint32 array. split 의 결과도 같은 모양:

key = random.PRNGKey(42)
print(key.shape)         # (2,)

new_keys = random.split(key, 5)
print(new_keys.shape)    # (5, 2)

key 사용의 표준 패턴

관용적으로 — split 결과의 첫 번째를 "next key" 로 보존, 나머지를 사용:

key = random.PRNGKey(0)

# 단계 1
key, subkey = random.split(key)
a = random.normal(subkey, (10,))

# 단계 2
key, subkey = random.split(key)
b = random.uniform(subkey, (10,))

# 단계 3
key, subkey = random.split(key)
c = random.bernoulli(subkey, p=0.5, shape=(10,))

이 패턴이 — JAX random 코드 어디서나. key, subkey = random.split(key) 가 거의 mantra.

여러 random 이 동시에 필요할 때

key = random.PRNGKey(0)

# 한 번에 여러 split
keys = random.split(key, 5)
key, k_init, k_dropout, k_noise, k_perm = keys

# 또는 unpack
key, k1, k2, k3 = random.split(key, 4)

tree-shaped key

큰 model 이면 — params 의 각 layer 에 다른 key 필요. 한 번에 split:

key = random.PRNGKey(0)
n_layers = 5
layer_keys = random.split(key, n_layers)

for i, lk in enumerate(layer_keys):
    layer_params[i] = init_layer(lk, in_dim, out_dim)

💡 split 은 cheap

split 이 비싸지 않을까 걱정? 안 그럼. split 은 deterministic 한 hash 함수 비슷해서 — 매우 빠름. key 를 함수 인자로 자유롭게 통과시키고, 필요할 때마다 split 해도 성능 영향 거의 없음.

fold_in — key 를 정수 식별자에 따라 deterministic 하게 변형:

# key 를 step 번호와 결합
def step_key(base_key, step):
    return random.fold_in(base_key, step)

base = random.PRNGKey(0)
for step in range(100):
    key_at_step = step_key(base, step)
    # 매 step 의 key 가 deterministic + 독립적

fold_in 은 split 보다 — "정수 → key" mapping 이라는 의도가 더 명확. 학습 step 마다 key 를 만들 때 흔히 사용.

Code

import jax
import jax.numpy as jnp

# Create a starting key from a seed
key = jax.random.key(0)
print(key)  # a special key array

# WRONG: reusing the same key gives the same numbers
print(jax.random.normal(key))  # 0.18784384
print(jax.random.normal(key))  # 0.18784384 — identical!

# RIGHT: split the key to get new, independent keys
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey))  # -1.2515389

key, subkey = jax.random.split(key)
print(jax.random.normal(subkey))  # -0.5841975 — different!
# Split into many keys at once
key = jax.random.key(42)
keys = jax.random.split(key, num=5)
print(keys.shape)  # (5,) — five independent keys

# Use each key for a different purpose
samples = jax.vmap(jax.random.normal)(keys)
print(samples)  # 5 independent random numbers
# Modern API (recommended)
key = jax.random.key(0)

# Legacy API (still works)
key_legacy = jax.random.PRNGKey(0)

# Both produce valid keys for all jax.random functions

External links

Exercise

key tree 작성: PRNGKey(0) 시작, 4 개로 split, 각각 또 4 개. 16 개 독립 normal sampling. pairwise correlation 이 0 에 가까운 거 검증. branching key 모델이 — JAX random 의 합성 가능성의 이유.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.