C.W.K.
Stream
Lesson 02 of 04 · published

in_axes 와 out_axes: vmap 컨트롤

~8 min · vmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.vmap 의 핵심 인자 — in_axesout_axes. 어느 인자를 batch 화하고, 어디서 batch axis 가 나올지 정함.

in_axes

각 인자가 어느 axis 로 batch 되는지:

def f(x, y):
    return x + y

# 모두 axis 0 (default)
batched = jax.vmap(f)
batched(jnp.zeros((10, 3)), jnp.zeros((10, 3)))   # → (10, 3)

# 다른 axis 지정
batched = jax.vmap(f, in_axes=(0, 1))
batched(jnp.zeros((10, 3)), jnp.zeros((3, 10)))   # → (10, 3)

# None — broadcast
batched = jax.vmap(f, in_axes=(0, None))
batched(jnp.zeros((10, 3)), jnp.array([1., 2., 3.]))  # y broadcast

pytree 인자 — pytree 모양에 맞는 in_axes:

def g(params, x):
    return x @ params["W"] + params["b"]

# params 는 broadcast, x 는 batch
batched_g = jax.vmap(g, in_axes=({"W": None, "b": None}, 0))

# 또는 단순히
batched_g = jax.vmap(g, in_axes=(None, 0))   # 전체 pytree 를 None 으로

out_axes

출력의 어느 axis 에 batch dim 이 들어갈지:

def f(x):
    return jnp.array([x, x ** 2])  # (2,)

# default — out_axes=0, batch dim 이 axis 0 에
batched = jax.vmap(f)
out = batched(jnp.arange(10.))   # shape: (10, 2)

# axis 1 에 두고 싶으면
batched = jax.vmap(f, out_axes=1)
out = batched(jnp.arange(10.))   # shape: (2, 10)

# 출력이 tuple 이면 — 각각 다른 axis
def h(x):
    return x ** 2, x ** 3

batched = jax.vmap(h, out_axes=(0, 1))
y_sq, y_cu = batched(jnp.arange(10.))
# y_sq.shape == (10,), y_cu.shape == (10,) — out_axes 는 batch dim 위치만

in_axes 의 흔한 패턴

# 1. 모델 forward — params broadcast, input batch
jax.vmap(model_apply, in_axes=(None, 0))

# 2. per-example loss — params + x + y 모두 batch
jax.vmap(loss_fn, in_axes=(None, 0, 0))

# 3. attention — Q (batch), K/V (전체)
jax.vmap(attention, in_axes=(0, None, None))

# 4. nested vmap — 2D batching
# 첫 vmap: axis 0
# 두 번째 vmap: axis 1
double_batched = jax.vmap(jax.vmap(f, in_axes=0), in_axes=0)

💡 shape 추적은 head exercise

vmap 사용할 때 — input shape 과 output shape 을 종이에 적어 보면 헷갈림이 사라져. print(jax.eval_shape(batched_f, x)) 로 함수 안 돌리고 shape 만 검증할 수도 있어. 처음엔 종이, 익으면 머릿속.

중요한 한 가지 — in_axes 의 모든 인자는 같은 batch dim 크기를 가져야 함:

batched = jax.vmap(f, in_axes=(0, 0))
batched(jnp.zeros((10, 3)), jnp.zeros((20, 3)))   # ❌ 10 != 20

다른 batch 길이를 처리해야 하면 — 두 번 vmap 하거나, padding 으로 길이 맞추거나, scan 으로 풀어.

Code

import jax
import jax.numpy as jnp

def dot_product(a, b):
    """Dot product of two vectors."""
    return jnp.sum(a * b)

# in_axes=(0, 0): vectorize over axis 0 of both a and b
batched_dot = jax.vmap(dot_product, in_axes=(0, 0))

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (3, 2)
B = jnp.array([[0.5, 0.5], [1.0, 0.0], [0.0, 1.0]])  # (3, 2)

# Computes dot product for each pair of rows
print(batched_dot(A, B))  # [1.5, 3.0, 6.0]
import jax
import jax.numpy as jnp

def scale_and_add(x, scale, bias):
    return x * scale + bias

# Vectorize over x (axis 0), but broadcast scale and bias
batched = jax.vmap(scale_and_add, in_axes=(0, None, None))

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
scale = jnp.array(2.0)
bias = jnp.array([0.1, 0.2])

result = batched(X, scale, bias)
print(result)
# [[ 2.1  4.2]
#  [ 6.1  8.2]
#  [10.1 12.2]]
import jax
import jax.numpy as jnp

def matrix_vector_product(matrix, vector):
    return matrix @ vector

# matrix: vectorize over axis 0 (batch of matrices)
# vector: vectorize over axis 0 (batch of vectors)
batched_mv = jax.vmap(matrix_vector_product, in_axes=(0, 0))

matrices = jax.random.normal(jax.random.PRNGKey(0), (5, 3, 4))  # 5 matrices of 3x4
vectors = jax.random.normal(jax.random.PRNGKey(1), (5, 4))       # 5 vectors of length 4

result = batched_mv(matrices, vectors)
print(result.shape)  # (5, 3) — 5 results of length 3
import jax
import jax.numpy as jnp

def my_fn(x):
    return x ** 2

# Default: batch dimension at position 0 in output
default = jax.vmap(my_fn)(jnp.ones((5, 3)))
print(default.shape)  # (5, 3) — batch dim at 0

# Put batch dimension at position 1 in output
swapped = jax.vmap(my_fn, out_axes=1)(jnp.ones((5, 3)))
print(swapped.shape)  # (3, 5) — batch dim at 1

External links

Exercise

3 개 인자 함수. 두 번째 인자만 vmap (in_axes=(None, 0, None)). 그 다음 (0, 0, None). 매번 output shape 출력. 실행 전에 output shape 예측 가능할 때까지.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.