C.W.K.
Stream
Lesson 04 of 06 · published

JAX 의 Random 분포

~8 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 가 제공하는 분포 — NumPy 와 거의 동일한 set, 다만 key 인자 추가.

import jax
from jax import random

key = random.PRNGKey(0)

# 기본 분포
random.uniform(key, shape, minval=0, maxval=1)
random.normal(key, shape)            # 표준 정규분포
random.bernoulli(key, p=0.5, shape=())
random.exponential(key, shape)
random.gamma(key, a, shape)
random.poisson(key, lam, shape)
random.chisquare(key, df, shape)
random.t(key, df, shape)
random.beta(key, a, b, shape)
random.gumbel(key, shape)
random.laplace(key, shape)
random.logistic(key, shape)
random.cauchy(key, shape)
random.dirichlet(key, alpha, shape)

# 이산 분포
random.categorical(key, logits, shape=())
random.choice(key, a, shape, replace=True, p=None)

# permutation
random.permutation(key, x, axis=0, independent=False)

# 정수
random.randint(key, shape, minval, maxval)
random.bits(key, shape, dtype=jnp.uint32)

대부분 — NumPy 와 동일한 인자 + 첫 인자가 key.

실전 사용 예

# 1. 가중치 초기화 — Xavier/Glorot
key = random.PRNGKey(0)
fan_in, fan_out = 100, 50
limit = jnp.sqrt(6 / (fan_in + fan_out))
W = random.uniform(key, (fan_in, fan_out), minval=-limit, maxval=limit)

# 또는 He
W_he = random.normal(key, (fan_in, fan_out)) * jnp.sqrt(2 / fan_in)

# 2. Dropout
def dropout(x, key, p=0.5):
    mask = random.bernoulli(key, 1 - p, x.shape)
    return x * mask / (1 - p)

# 3. Sampling from softmax
logits = jnp.array([1.0, 2.0, 0.5, 3.0])
sample = random.categorical(key, logits, shape=(10,))   # (10,) 의 indices

# 4. Gumbel-softmax (differentiable sampling)
g = random.gumbel(key, logits.shape)
soft_sample = jax.nn.softmax((logits + g) / temperature)

# 5. Permutation
indices = jnp.arange(100)
shuffled = random.permutation(key, indices)

NumPy 호환성

NumPy 코드의 random call 을 옮길 때 — key 인자만 추가:

np.random.normal(0, 1, (10,))      → random.normal(key, (10,))  (loc=0, scale=1)
np.random.uniform(low=0, high=1)    → random.uniform(key, ...)
np.random.choice(a, size=10)        → random.choice(key, a, shape=(10,))
np.random.permutation(x)            → random.permutation(key, x)

주의: NumPy 의 일부 분포는 — JAX 에선 약간 다른 이름이거나 인자 순서 다름. 새로운 분포 처음 쓸 때 docs 한 번 확인.

💡 normal 은 standard 만

JAX 의 random.normal 은 — 평균 0, std 1 만. 다른 평균 / std 면 직접 transform: mu + std * random.normal(key, shape). NumPy 처럼 normal(loc, scale) 인자 직접 받는 게 아님. 한 번 깜박하면 헷갈림.

multivariate normal 도 별도 함수 (random.multivariate_normal) — 인자 형태가 자세함. 큰 covariance 에 주의 (Cholesky 분해 비용).

Code

import jax
import jax.numpy as jnp

key = jax.random.key(42)
keys = jax.random.split(key, 8)

# Uniform [0, 1)
uniform = jax.random.uniform(keys[0], shape=(3,))
# [0.299, 0.784, 0.033]

# Normal (mean=0, std=1)
normal = jax.random.normal(keys[1], shape=(3,))
# [-0.272, 1.085, -0.533]

# Bernoulli (coin flips)
coins = jax.random.bernoulli(keys[2], p=0.7, shape=(5,))
# [True, True, False, True, True]

# Categorical (sample from discrete distribution)
logits = jnp.array([1.0, 2.0, 0.5])
category = jax.random.categorical(keys[3], logits, shape=(4,))
# [1, 1, 0, 1] — category 1 most likely

# Truncated normal (clipped to [lower, upper])
trunc = jax.random.truncated_normal(keys[4], lower=-2.0, upper=2.0, shape=(3,))

# Randint (random integers)
ints = jax.random.randint(keys[5], shape=(3,), minval=0, maxval=10)

# Permutation (shuffling)
shuffled = jax.random.permutation(keys[6], jnp.arange(5))

# Exponential
exp_samples = jax.random.exponential(keys[7], shape=(3,))
# NumPy                              # JAX
# np.random.randn(3, 4)             jax.random.normal(key, (3, 4))
# np.random.rand(3, 4)              jax.random.uniform(key, (3, 4))
# np.random.randint(0, 10, (3,))    jax.random.randint(key, (3,), 0, 10)
# np.random.choice(arr, size=5)     jax.random.choice(key, arr, (5,))
# np.random.shuffle(arr)            jax.random.permutation(key, arr)

# Key difference: JAX always needs a key, and never mutates input
# Xavier/Glorot initialization from scratch
def glorot_normal(key, shape):
    fan_in, fan_out = shape[-2], shape[-1]
    std = jnp.sqrt(2.0 / (fan_in + fan_out))
    return jax.random.normal(key, shape) * std

key = jax.random.key(0)
w = glorot_normal(key, (256, 128))
print(f"Mean: {w.mean():.4f}, Std: {w.std():.4f}")
# Mean: ~0.0, Std: ~0.072

External links

Exercise

같은 key 로 normal, uniform, bernoulli, categorical, gumbel 분포에서 sampling. 1000 sample 의 empirical mean + variance 와 분석값 비교. 가장 검증하기 어려운 분포 식별.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.