vmap 은 stack 가능 — vmap 안에 vmap, 또 안에 vmap. 다차원 batch 처리에 쓰임.
2D batching — image grid 같은 2 차원 batch:
def pixel_op(x):
'''단일 pixel 처리 — x: scalar'''
return jnp.tanh(x) * 2
# 2D image — 두 번 vmap
img_op = jax.vmap(jax.vmap(pixel_op))
img = jnp.zeros((28, 28))
out = img_op(img) # (28, 28)
# 3D — image batch
batch_img_op = jax.vmap(img_op)
imgs = jnp.zeros((32, 28, 28))
out = batch_img_op(imgs) # (32, 28, 28)
Outer product 를 vmap 으로:
def scalar_mul(a, b):
'''단일 scalar 곱셈'''
return a * b
# x_i * y_j 의 outer product matrix
outer = jax.vmap(jax.vmap(scalar_mul, in_axes=(None, 0)), in_axes=(0, None))
x = jnp.array([1., 2., 3.]) # (3,)
y = jnp.array([4., 5.]) # (2,)
M = outer(x, y) # (3, 2)
# M[i, j] = x[i] * y[j]
jnp.outer 와 같은 결과 — 차이는 어떤 함수든 일반화 가능. scalar_mul 자리에 더 복잡한 함수를 넣어도 같은 패턴.
Pairwise 거리 계산:
def euclidean(x, y):
'''두 vector 의 euclidean 거리'''
return jnp.sqrt(jnp.sum((x - y) ** 2))
# pairwise 거리 행렬
def pairwise(X, Y):
return jax.vmap(
jax.vmap(euclidean, in_axes=(None, 0)),
in_axes=(0, None),
)(X, Y)
X = jnp.zeros((100, 5))
Y = jnp.zeros((50, 5))
D = pairwise(X, Y) # (100, 50) 거리 행렬
scipy 의 cdist 같은 일을 — JAX 에서, 가속기 위에서, 미분 가능한 형태로.
vmap 의 한계와 jnp.einsum
vmap nesting 이 너무 깊어지면 — XLA 가 효율적으로 컴파일하기 힘듦. 단순 tensor 연산이면 — jnp.einsum 이 더 빠를 수 있어.
# 둘 다 (B, M, N) ↔ (B, N, P) → (B, M, P) batched matmul
# 방법 1: vmap
batched_matmul_v = jax.vmap(jnp.matmul, in_axes=(0, 0))
# 방법 2: einsum
batched_matmul_e = lambda A, B: jnp.einsum("bij,bjk->bik", A, B)
# 방법 3: 그냥 jnp.matmul — broadcasting 으로 자동
batched_matmul_n = jnp.matmul
모두 같은 결과. 단순 케이스는 jnp 의 native 연산이 가장 깔끔. 임의의 함수를 batch 화 해야 할 때만 vmap.
💡 vmap 디버깅 팁
nested vmap 디버깅이 어려우면 — 단계 단계 풀어. 가장 안쪽 함수를 단일 input 으로 호출 → 결과 확인. 한 vmap 추가 → 결과 확인. 또 추가 → 확인. shape 이 의도와 다르면 거기서 멈춤. 한 번에 4 단 vmap 짜고 디버깅하지 마.
실용 가이드:
- 1D batch — 한 vmap
- 2D batch (예: image grid) — nested vmap 또는 reshape + 단일 vmap
- 단순 tensor 연산 — jnp 의 native broadcasting / einsum
- 임의 함수의 N-D batch — vmap 합성
- 3 단 이상 nesting — 의심. einsum 으로 표현 가능한지 먼저 봐.