vmap 안에서 random 이 어떻게 동작하는지 — 매우 중요. 잘못 짜면 — batch 의 모든 example 이 같은 noise 를 받아서 학습이 깨짐.
흔한 함정
def add_noise(x, key):
return x + random.normal(key, x.shape) * 0.1
x_batch = jnp.zeros((32, 10)) # 32 examples
key = random.PRNGKey(0)
# 잘못 — 모든 example 이 같은 noise
batched = jax.vmap(add_noise, in_axes=(0, None))
y = batched(x_batch, key)
# 32 examples 다 똑같은 noise pattern. bug!
올바른 패턴 — 각 example 마다 다른 key:
def add_noise(x, key):
return x + random.normal(key, x.shape) * 0.1
# 32 개의 독립적 key
keys = random.split(key, 32)
# vmap 으로 — 각 example 이 자기 key
batched = jax.vmap(add_noise, in_axes=(0, 0))
y = batched(x_batch, keys)
# 32 examples 다 다른 noise!
학습 loop 에서
def per_example_dropout(x, key, p=0.5):
'''단일 example 의 dropout'''
mask = random.bernoulli(key, 1 - p, x.shape)
return x * mask / (1 - p)
@jax.jit
def model_step(params, x_batch, key):
'''batch 에 dropout 적용'''
keys = random.split(key, x_batch.shape[0])
apply = jax.vmap(per_example_dropout, in_axes=(0, 0))
return apply(x_batch, keys)
이 패턴 — vmap 입력의 batch axis 길이만큼 key 를 split 한 다음 vmap.
partitionable PRNG (JAX 0.4+)
새 JAX 의 PRNG 는 — array 모양으로 직접 처리 가능. random.normal(key, (32, 10)) 한 번이 — split 후 32 번 호출과 통계적으로 동등하지만 더 효율.
# 이 둘은 다른 결과지만 둘 다 OK
# 방식 1: 한 key, big shape
big = random.normal(key, (32, 10)) # 32 example × 10 feature 의 noise
# 방식 2: split, vmap
keys = random.split(key, 32)
small = jax.vmap(lambda k: random.normal(k, (10,)))(keys)
방식 1 이 더 빠르고 메모리 효율 좋음. 그러나 — 함수가 단일 example 단위로 작성되어 있고 vmap 으로 batch 처리하는 패턴에선 방식 2 가 자연스러움.
nested vmap + random
# 2D batch — 예: (n_chains, batch_size, ...)
def f(x, key):
return x + random.normal(key, x.shape)
n_chains, batch_size = 4, 32
keys_2d = random.split(key, n_chains * batch_size).reshape(n_chains, batch_size, -1)
# 두 번 vmap
batched = jax.vmap(jax.vmap(f, in_axes=(0, 0)), in_axes=(0, 0))
y = batched(x, keys_2d)
🔑 vmap 안의 random 규칙
vmap 의 각 batch element 가 독립적 random 을 보려면 — 각자에게 독립적 key 를 줘야 한다. 한 key 를 broadcast 하면 모두가 같은 random 을 봐 (가끔 그게 의도지만 거의 항상 버그). split → vmap, 또는 한 번에 큰 shape 으로 sampling.
실전 디버깅 팁: 학습이 이상하면 — 첫 번째와 마지막 example 의 batch 결과를 print. 같으면 random seed bug. 다르면 random 은 OK, 다른 곳을 봐야 함.