jax.vmap 의 핵심 인자 — in_axes 와 out_axes. 어느 인자를 batch 화하고, 어디서 batch axis 가 나올지 정함.
in_axes
각 인자가 어느 axis 로 batch 되는지:
def f(x, y):
return x + y
# 모두 axis 0 (default)
batched = jax.vmap(f)
batched(jnp.zeros((10, 3)), jnp.zeros((10, 3))) # → (10, 3)
# 다른 axis 지정
batched = jax.vmap(f, in_axes=(0, 1))
batched(jnp.zeros((10, 3)), jnp.zeros((3, 10))) # → (10, 3)
# None — broadcast
batched = jax.vmap(f, in_axes=(0, None))
batched(jnp.zeros((10, 3)), jnp.array([1., 2., 3.])) # y broadcast
pytree 인자 — pytree 모양에 맞는 in_axes:
def g(params, x):
return x @ params["W"] + params["b"]
# params 는 broadcast, x 는 batch
batched_g = jax.vmap(g, in_axes=({"W": None, "b": None}, 0))
# 또는 단순히
batched_g = jax.vmap(g, in_axes=(None, 0)) # 전체 pytree 를 None 으로
out_axes
출력의 어느 axis 에 batch dim 이 들어갈지:
def f(x):
return jnp.array([x, x ** 2]) # (2,)
# default — out_axes=0, batch dim 이 axis 0 에
batched = jax.vmap(f)
out = batched(jnp.arange(10.)) # shape: (10, 2)
# axis 1 에 두고 싶으면
batched = jax.vmap(f, out_axes=1)
out = batched(jnp.arange(10.)) # shape: (2, 10)
# 출력이 tuple 이면 — 각각 다른 axis
def h(x):
return x ** 2, x ** 3
batched = jax.vmap(h, out_axes=(0, 1))
y_sq, y_cu = batched(jnp.arange(10.))
# y_sq.shape == (10,), y_cu.shape == (10,) — out_axes 는 batch dim 위치만
in_axes 의 흔한 패턴
# 1. 모델 forward — params broadcast, input batch
jax.vmap(model_apply, in_axes=(None, 0))
# 2. per-example loss — params + x + y 모두 batch
jax.vmap(loss_fn, in_axes=(None, 0, 0))
# 3. attention — Q (batch), K/V (전체)
jax.vmap(attention, in_axes=(0, None, None))
# 4. nested vmap — 2D batching
# 첫 vmap: axis 0
# 두 번째 vmap: axis 1
double_batched = jax.vmap(jax.vmap(f, in_axes=0), in_axes=0)
💡 shape 추적은 head exercise
vmap 사용할 때 — input shape 과 output shape 을 종이에 적어 보면 헷갈림이 사라져. print(jax.eval_shape(batched_f, x)) 로 함수 안 돌리고 shape 만 검증할 수도 있어. 처음엔 종이, 익으면 머릿속.
중요한 한 가지 — in_axes 의 모든 인자는 같은 batch dim 크기를 가져야 함:
batched = jax.vmap(f, in_axes=(0, 0))
batched(jnp.zeros((10, 3)), jnp.zeros((20, 3))) # ❌ 10 != 20
다른 batch 길이를 처리해야 하면 — 두 번 vmap 하거나, padding 으로 길이 맞추거나, scan 으로 풀어.