ML 코드 작성에서 가장 짜증나는 부분 — batch dimension 을 손으로 처리하는 거. jax.vmap 이 그걸 자동화해.
한 example 에 작동하는 함수 작성:
import jax
import jax.numpy as jnp
# 한 vector 의 dot product
def dot(x, y):
return jnp.sum(x * y)
x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
print(dot(x, y)) # 32.0
이제 batch 로 — 1000 개의 (x, y) pair 의 dot products 를 한 번에:
batch_x = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
batch_y = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))
# 옵션 1: Python for-loop — 느림, 비-JAX 다움
results = [dot(batch_x[i], batch_y[i]) for i in range(1000)]
# 옵션 2: 함수 다시 작성 — broadcasting 신경 써야 함
def batch_dot(X, Y):
return jnp.sum(X * Y, axis=-1)
results = batch_dot(batch_x, batch_y)
# 옵션 3: vmap — 함수 안 바꿈
batched = jax.vmap(dot)
results = batched(batch_x, batch_y) # shape: (1000,)
옵션 3 의 매력 — 원래 함수 그대로. jax.vmap(dot) 이 새 함수를 돌려주는데, 첫 번째 axis 를 batch 차원으로 자동 추가.
같은 패턴 — 모델 forward 에 적용:
# 한 example 의 forward
def model(params, x):
h = jnp.tanh(x @ params["W1"] + params["b1"])
return h @ params["W2"] + params["b2"]
# 단일 example
y = model(params, jnp.zeros(784)) # x: (784,) → y: (10,)
# Batch — vmap 한 번
batched_model = jax.vmap(model, in_axes=(None, 0))
ys = batched_model(params, batch_x) # batch_x: (32, 784) → ys: (32, 10)
in_axes=(None, 0) — params 는 broadcast (None: batch axis 없음), x 는 0 번 axis 가 batch.
vmap vs broadcasting
둘 다 — 같은 결과. 차이는 사고 방식:
- broadcasting: "batch axis 를 미리 처리해서 모든 op 가 자동으로 맞아 떨어지길".
X @ W는 X.shape=(B, D), W.shape=(D, D') 일 때 자동. - vmap: "한 example 코드를 그대로 두고, JAX 한테 batch 로 풀어 달라고 부탁". 함수 정의 안 바꿈.
broadcasting 으로 안 되는 게 vmap 으로 됨:
# scan 같은 sequential 연산은 broadcasting 으로 batch 화 어려움
def cumsum_first_n(x, n):
return jnp.sum(x[:n])
# vmap 으로 batch 화
batched = jax.vmap(cumsum_first_n, in_axes=(0, 0))
xs = jnp.zeros((10, 100))
ns = jnp.arange(10) # 각 example 마다 다른 n
results = batched(xs, ns)
🌟 "한 example 의 정확한 함수 → batch 자동"
JAX 코드 작성 ergonomics 의 핵심. 모든 함수를 한 example 기준으로 작성. batch dim 은 vmap 에 위임. 코드가 깨끗해지고, batch dim 의 layout 변화에도 강해짐. 모든 후속 quest 의 NN 코드가 이 패턴을 따라.