JAX 가 제공하는 분포 — NumPy 와 거의 동일한 set, 다만 key 인자 추가.
import jax
from jax import random
key = random.PRNGKey(0)
# 기본 분포
random.uniform(key, shape, minval=0, maxval=1)
random.normal(key, shape) # 표준 정규분포
random.bernoulli(key, p=0.5, shape=())
random.exponential(key, shape)
random.gamma(key, a, shape)
random.poisson(key, lam, shape)
random.chisquare(key, df, shape)
random.t(key, df, shape)
random.beta(key, a, b, shape)
random.gumbel(key, shape)
random.laplace(key, shape)
random.logistic(key, shape)
random.cauchy(key, shape)
random.dirichlet(key, alpha, shape)
# 이산 분포
random.categorical(key, logits, shape=())
random.choice(key, a, shape, replace=True, p=None)
# permutation
random.permutation(key, x, axis=0, independent=False)
# 정수
random.randint(key, shape, minval, maxval)
random.bits(key, shape, dtype=jnp.uint32)
대부분 — NumPy 와 동일한 인자 + 첫 인자가 key.
실전 사용 예
# 1. 가중치 초기화 — Xavier/Glorot
key = random.PRNGKey(0)
fan_in, fan_out = 100, 50
limit = jnp.sqrt(6 / (fan_in + fan_out))
W = random.uniform(key, (fan_in, fan_out), minval=-limit, maxval=limit)
# 또는 He
W_he = random.normal(key, (fan_in, fan_out)) * jnp.sqrt(2 / fan_in)
# 2. Dropout
def dropout(x, key, p=0.5):
mask = random.bernoulli(key, 1 - p, x.shape)
return x * mask / (1 - p)
# 3. Sampling from softmax
logits = jnp.array([1.0, 2.0, 0.5, 3.0])
sample = random.categorical(key, logits, shape=(10,)) # (10,) 의 indices
# 4. Gumbel-softmax (differentiable sampling)
g = random.gumbel(key, logits.shape)
soft_sample = jax.nn.softmax((logits + g) / temperature)
# 5. Permutation
indices = jnp.arange(100)
shuffled = random.permutation(key, indices)
NumPy 호환성
NumPy 코드의 random call 을 옮길 때 — key 인자만 추가:
np.random.normal(0, 1, (10,)) → random.normal(key, (10,)) (loc=0, scale=1)
np.random.uniform(low=0, high=1) → random.uniform(key, ...)
np.random.choice(a, size=10) → random.choice(key, a, shape=(10,))
np.random.permutation(x) → random.permutation(key, x)
주의: NumPy 의 일부 분포는 — JAX 에선 약간 다른 이름이거나 인자 순서 다름. 새로운 분포 처음 쓸 때 docs 한 번 확인.
💡 normal 은 standard 만
JAX 의 random.normal 은 — 평균 0, std 1 만. 다른 평균 / std 면 직접 transform: mu + std * random.normal(key, shape). NumPy 처럼 normal(loc, scale) 인자 직접 받는 게 아님. 한 번 깜박하면 헷갈림.
multivariate normal 도 별도 함수 (random.multivariate_normal) — 인자 형태가 자세함. 큰 covariance 에 주의 (Cholesky 분해 비용).