JAX random 의 모든 것 — PRNGKey + split. 두 함수면 끝.
import jax
from jax import random
# 시작 key 만들기
key = random.PRNGKey(0)
print(key) # uint32 array of shape (2,) — 64-bit seed
# split — 한 key 에서 N 개의 새 key
key, subkey = random.split(key) # 2 개 (default)
keys = random.split(key, 4) # 4 개
key1, key2, key3 = random.split(key, 3) # 3 개 unpack
구조 — key 는 그냥 (2,) shape uint32 array. split 의 결과도 같은 모양:
key = random.PRNGKey(42)
print(key.shape) # (2,)
new_keys = random.split(key, 5)
print(new_keys.shape) # (5, 2)
key 사용의 표준 패턴
관용적으로 — split 결과의 첫 번째를 "next key" 로 보존, 나머지를 사용:
key = random.PRNGKey(0)
# 단계 1
key, subkey = random.split(key)
a = random.normal(subkey, (10,))
# 단계 2
key, subkey = random.split(key)
b = random.uniform(subkey, (10,))
# 단계 3
key, subkey = random.split(key)
c = random.bernoulli(subkey, p=0.5, shape=(10,))
이 패턴이 — JAX random 코드 어디서나. key, subkey = random.split(key) 가 거의 mantra.
여러 random 이 동시에 필요할 때
key = random.PRNGKey(0)
# 한 번에 여러 split
keys = random.split(key, 5)
key, k_init, k_dropout, k_noise, k_perm = keys
# 또는 unpack
key, k1, k2, k3 = random.split(key, 4)
tree-shaped key
큰 model 이면 — params 의 각 layer 에 다른 key 필요. 한 번에 split:
key = random.PRNGKey(0)
n_layers = 5
layer_keys = random.split(key, n_layers)
for i, lk in enumerate(layer_keys):
layer_params[i] = init_layer(lk, in_dim, out_dim)
💡 split 은 cheap
split 이 비싸지 않을까 걱정? 안 그럼. split 은 deterministic 한 hash 함수 비슷해서 — 매우 빠름. key 를 함수 인자로 자유롭게 통과시키고, 필요할 때마다 split 해도 성능 영향 거의 없음.
fold_in — key 를 정수 식별자에 따라 deterministic 하게 변형:
# key 를 step 번호와 결합
def step_key(base_key, step):
return random.fold_in(base_key, step)
base = random.PRNGKey(0)
for step in range(100):
key_at_step = step_key(base, step)
# 매 step 의 key 가 deterministic + 독립적
fold_in 은 split 보다 — "정수 → key" mapping 이라는 의도가 더 명확. 학습 step 마다 key 를 만들 때 흔히 사용.