C.W.K.
Stream
Lesson 05 of 06 · published

vmap 과 Random: Partitionable PRNG

~11 min · random, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

vmap 안에서 random 이 어떻게 동작하는지 — 매우 중요. 잘못 짜면 — batch 의 모든 example 이 같은 noise 를 받아서 학습이 깨짐.

흔한 함정

def add_noise(x, key):
    return x + random.normal(key, x.shape) * 0.1

x_batch = jnp.zeros((32, 10))   # 32 examples
key = random.PRNGKey(0)

# 잘못 — 모든 example 이 같은 noise
batched = jax.vmap(add_noise, in_axes=(0, None))
y = batched(x_batch, key)
# 32 examples 다 똑같은 noise pattern. bug!

올바른 패턴 — 각 example 마다 다른 key:

def add_noise(x, key):
    return x + random.normal(key, x.shape) * 0.1

# 32 개의 독립적 key
keys = random.split(key, 32)

# vmap 으로 — 각 example 이 자기 key
batched = jax.vmap(add_noise, in_axes=(0, 0))
y = batched(x_batch, keys)
# 32 examples 다 다른 noise!

학습 loop 에서

def per_example_dropout(x, key, p=0.5):
    '''단일 example 의 dropout'''
    mask = random.bernoulli(key, 1 - p, x.shape)
    return x * mask / (1 - p)

@jax.jit
def model_step(params, x_batch, key):
    '''batch 에 dropout 적용'''
    keys = random.split(key, x_batch.shape[0])
    apply = jax.vmap(per_example_dropout, in_axes=(0, 0))
    return apply(x_batch, keys)

이 패턴 — vmap 입력의 batch axis 길이만큼 key 를 split 한 다음 vmap.

partitionable PRNG (JAX 0.4+)

새 JAX 의 PRNG 는 — array 모양으로 직접 처리 가능. random.normal(key, (32, 10)) 한 번이 — split 후 32 번 호출과 통계적으로 동등하지만 더 효율.

# 이 둘은 다른 결과지만 둘 다 OK
# 방식 1: 한 key, big shape
big = random.normal(key, (32, 10))   # 32 example × 10 feature 의 noise

# 방식 2: split, vmap
keys = random.split(key, 32)
small = jax.vmap(lambda k: random.normal(k, (10,)))(keys)

방식 1 이 더 빠르고 메모리 효율 좋음. 그러나 — 함수가 단일 example 단위로 작성되어 있고 vmap 으로 batch 처리하는 패턴에선 방식 2 가 자연스러움.

nested vmap + random

# 2D batch — 예: (n_chains, batch_size, ...)
def f(x, key):
    return x + random.normal(key, x.shape)

n_chains, batch_size = 4, 32
keys_2d = random.split(key, n_chains * batch_size).reshape(n_chains, batch_size, -1)

# 두 번 vmap
batched = jax.vmap(jax.vmap(f, in_axes=(0, 0)), in_axes=(0, 0))
y = batched(x, keys_2d)

🔑 vmap 안의 random 규칙

vmap 의 각 batch element 가 독립적 random 을 보려면 — 각자에게 독립적 key 를 줘야 한다. 한 key 를 broadcast 하면 모두가 같은 random 을 봐 (가끔 그게 의도지만 거의 항상 버그). split → vmap, 또는 한 번에 큰 shape 으로 sampling.

실전 디버깅 팁: 학습이 이상하면 — 첫 번째와 마지막 example 의 batch 결과를 print. 같으면 random seed bug. 다르면 random 은 OK, 다른 곳을 봐야 함.

Code

import jax
import jax.numpy as jnp

def sample_and_transform(key):
    """Generate one random sample and apply some transformation."""
    x = jax.random.normal(key, (3,))
    return jnp.sin(x) + x ** 2

# Generate 1000 independent samples using vmap
key = jax.random.key(42)
keys = jax.random.split(key, 1000)
results = jax.vmap(sample_and_transform)(keys)
print(results.shape)  # (1000, 3)
def dropout(x, key, rate=0.5):
    mask = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(mask, x / (1.0 - rate), 0.0)

# Apply different dropout masks to each sample in a batch
def forward_single(params, x, key):
    h = jax.nn.relu(x @ params['w1'] + params['b1'])
    k1, k2 = jax.random.split(key)
    h = dropout(h, k1, rate=0.3)
    out = h @ params['w2'] + params['b2']
    return out

# vmap over both inputs and keys
batch_forward = jax.vmap(forward_single, in_axes=(None, 0, 0))

# Each sample gets its own dropout mask
key = jax.random.key(0)
batch_keys = jax.random.split(key, 32)  # one key per sample
# predictions = batch_forward(params, batch_x, batch_keys)
# Before JAX 0.5.0: random ops could be slow under pmap/sharding
# because the PRNG wasn't designed for partitioning.

# Since JAX 0.5.0: partitionable by default!
# Random ops automatically shard across devices efficiently.
# No config changes needed.

# Important: the partitionable PRNG produces DIFFERENT values
# than the old non-partitionable version for the same seed.
# jax.random.key(42) gives different numbers in JAX 0.5+ vs 0.4.x
def augment_image(key, image):
    """Apply random augmentations to a single image."""
    k1, k2, k3 = jax.random.split(key, 3)

    # Random horizontal flip
    flip = jax.random.bernoulli(k1)
    image = jnp.where(flip, jnp.flip(image, axis=1), image)

    # Random brightness adjustment
    brightness = jax.random.uniform(k2, minval=0.8, maxval=1.2)
    image = image * brightness

    # Random crop offset (for a simple center crop with jitter)
    offset = jax.random.randint(k3, (2,), 0, 8)
    image = jax.lax.dynamic_slice(image, (*offset, 0),
                                   (224, 224, 3))
    return jnp.clip(image, 0.0, 1.0)

# Augment entire batch with independent randomness
batch_augment = jax.vmap(augment_image)
key = jax.random.key(0)
batch_keys = jax.random.split(key, 64)  # 64 images
# augmented_batch = batch_augment(batch_keys, image_batch)

External links

Exercise

vmap'd 함수 안에서 per-example noise 생성. jax.random.split 으로 batch 에 key fan-out. 각 example 이 독립 randomness 받는 거 확인. 이 패턴이 막는 bug class — batch 전체에 같은 noise — 가 실제 research 에서 모델을 silently kill.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.