C.W.K.
Stream
Lesson 01 of 04 · published

vmap 의 일: 단일 example 에서 batch 로

~8 min · vmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

ML 코드 작성에서 가장 짜증나는 부분 — batch dimension 을 손으로 처리하는 거. jax.vmap 이 그걸 자동화해.

한 example 에 작동하는 함수 작성:

import jax
import jax.numpy as jnp

# 한 vector 의 dot product
def dot(x, y):
    return jnp.sum(x * y)

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
print(dot(x, y))   # 32.0

이제 batch 로 — 1000 개의 (x, y) pair 의 dot products 를 한 번에:

batch_x = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
batch_y = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))

# 옵션 1: Python for-loop — 느림, 비-JAX 다움
results = [dot(batch_x[i], batch_y[i]) for i in range(1000)]

# 옵션 2: 함수 다시 작성 — broadcasting 신경 써야 함
def batch_dot(X, Y):
    return jnp.sum(X * Y, axis=-1)
results = batch_dot(batch_x, batch_y)

# 옵션 3: vmap — 함수 안 바꿈
batched = jax.vmap(dot)
results = batched(batch_x, batch_y)  # shape: (1000,)

옵션 3 의 매력 — 원래 함수 그대로. jax.vmap(dot) 이 새 함수를 돌려주는데, 첫 번째 axis 를 batch 차원으로 자동 추가.

같은 패턴 — 모델 forward 에 적용:

# 한 example 의 forward
def model(params, x):
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

# 단일 example
y = model(params, jnp.zeros(784))    # x: (784,) → y: (10,)

# Batch — vmap 한 번
batched_model = jax.vmap(model, in_axes=(None, 0))
ys = batched_model(params, batch_x)   # batch_x: (32, 784) → ys: (32, 10)

in_axes=(None, 0) — params 는 broadcast (None: batch axis 없음), x 는 0 번 axis 가 batch.

vmap vs broadcasting

둘 다 — 같은 결과. 차이는 사고 방식:

  • broadcasting: "batch axis 를 미리 처리해서 모든 op 가 자동으로 맞아 떨어지길". X @ W 는 X.shape=(B, D), W.shape=(D, D') 일 때 자동.
  • vmap: "한 example 코드를 그대로 두고, JAX 한테 batch 로 풀어 달라고 부탁". 함수 정의 안 바꿈.

broadcasting 으로 안 되는 게 vmap 으로 됨:

# scan 같은 sequential 연산은 broadcasting 으로 batch 화 어려움
def cumsum_first_n(x, n):
    return jnp.sum(x[:n])

# vmap 으로 batch 화
batched = jax.vmap(cumsum_first_n, in_axes=(0, 0))
xs = jnp.zeros((10, 100))
ns = jnp.arange(10)  # 각 example 마다 다른 n
results = batched(xs, ns)

🌟 "한 example 의 정확한 함수 → batch 자동"

JAX 코드 작성 ergonomics 의 핵심. 모든 함수를 한 example 기준으로 작성. batch dim 은 vmap 에 위임. 코드가 깨끗해지고, batch dim 의 layout 변화에도 강해짐. 모든 후속 quest 의 NN 코드가 이 패턴을 따라.

Code

import jax
import jax.numpy as jnp

# A function for ONE example
def predict_single(weights, x):
    """Linear prediction for a single input vector."""
    return jnp.dot(weights, x)

weights = jnp.array([1.0, 2.0, 3.0])
single_x = jnp.array([0.5, 0.3, 0.8])

# Works for one example
print(predict_single(weights, single_x))  # 3.5

# Now handle a BATCH of examples — just use vmap!
batch_predict = jax.vmap(predict_single, in_axes=(None, 0))

batch_x = jnp.array([
    [0.5, 0.3, 0.8],
    [1.0, 0.0, 0.5],
    [0.2, 0.4, 0.6],
])

# vmap automatically applies predict_single to each row
print(batch_predict(weights, batch_x))  # [3.5, 2.5, 2.8]
# Manual batching — error-prone and clutters the code
def predict_batch_manual(weights, X):
    return X @ weights  # Had to think about dimensions

# vmap version — no dimensional reasoning needed
predict_batch_vmap = jax.vmap(predict_single, in_axes=(None, 0))

# Both produce the same result, but vmap is cleaner

External links

Exercise

single 1D vector 의 dot(x, y) 함수. 함수 안 바꾸고 1000 (x, y) pair 에 vmap. Python loop 와 검증. 두 번째 vmap 추가 — batch of batches 처리.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.