C.W.K.
Stream
Lesson 04 of 04 · published

Nested vmap 과 고급 패턴

~9 min · vmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

vmap 은 stack 가능 — vmap 안에 vmap, 또 안에 vmap. 다차원 batch 처리에 쓰임.

2D batching — image grid 같은 2 차원 batch:

def pixel_op(x):
    '''단일 pixel 처리 — x: scalar'''
    return jnp.tanh(x) * 2

# 2D image — 두 번 vmap
img_op = jax.vmap(jax.vmap(pixel_op))
img = jnp.zeros((28, 28))
out = img_op(img)   # (28, 28)

# 3D — image batch
batch_img_op = jax.vmap(img_op)
imgs = jnp.zeros((32, 28, 28))
out = batch_img_op(imgs)   # (32, 28, 28)

Outer product 를 vmap 으로:

def scalar_mul(a, b):
    '''단일 scalar 곱셈'''
    return a * b

# x_i * y_j 의 outer product matrix
outer = jax.vmap(jax.vmap(scalar_mul, in_axes=(None, 0)), in_axes=(0, None))
x = jnp.array([1., 2., 3.])      # (3,)
y = jnp.array([4., 5.])          # (2,)
M = outer(x, y)                  # (3, 2)
# M[i, j] = x[i] * y[j]

jnp.outer 와 같은 결과 — 차이는 어떤 함수든 일반화 가능. scalar_mul 자리에 더 복잡한 함수를 넣어도 같은 패턴.

Pairwise 거리 계산:

def euclidean(x, y):
    '''두 vector 의 euclidean 거리'''
    return jnp.sqrt(jnp.sum((x - y) ** 2))

# pairwise 거리 행렬
def pairwise(X, Y):
    return jax.vmap(
        jax.vmap(euclidean, in_axes=(None, 0)),
        in_axes=(0, None),
    )(X, Y)

X = jnp.zeros((100, 5))
Y = jnp.zeros((50, 5))
D = pairwise(X, Y)   # (100, 50) 거리 행렬

scipy 의 cdist 같은 일을 — JAX 에서, 가속기 위에서, 미분 가능한 형태로.

vmap 의 한계와 jnp.einsum

vmap nesting 이 너무 깊어지면 — XLA 가 효율적으로 컴파일하기 힘듦. 단순 tensor 연산이면 — jnp.einsum 이 더 빠를 수 있어.

# 둘 다 (B, M, N) ↔ (B, N, P) → (B, M, P) batched matmul
# 방법 1: vmap
batched_matmul_v = jax.vmap(jnp.matmul, in_axes=(0, 0))

# 방법 2: einsum
batched_matmul_e = lambda A, B: jnp.einsum("bij,bjk->bik", A, B)

# 방법 3: 그냥 jnp.matmul — broadcasting 으로 자동
batched_matmul_n = jnp.matmul

모두 같은 결과. 단순 케이스는 jnp 의 native 연산이 가장 깔끔. 임의의 함수를 batch 화 해야 할 때만 vmap.

💡 vmap 디버깅 팁

nested vmap 디버깅이 어려우면 — 단계 단계 풀어. 가장 안쪽 함수를 단일 input 으로 호출 → 결과 확인. 한 vmap 추가 → 결과 확인. 또 추가 → 확인. shape 이 의도와 다르면 거기서 멈춤. 한 번에 4 단 vmap 짜고 디버깅하지 마.

실용 가이드:

  • 1D batch — 한 vmap
  • 2D batch (예: image grid) — nested vmap 또는 reshape + 단일 vmap
  • 단순 tensor 연산 — jnp 의 native broadcasting / einsum
  • 임의 함수의 N-D batch — vmap 합성
  • 3 단 이상 nesting — 의심. einsum 으로 표현 가능한지 먼저 봐.

Code

import jax
import jax.numpy as jnp

def pairwise_distance(a, b):
    """Euclidean distance between two vectors."""
    return jnp.sqrt(jnp.sum((a - b) ** 2))

# Vectorize over pairs: distance between corresponding rows
batched_distance = jax.vmap(pairwise_distance)

A = jnp.array([[1.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
B = jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])

print(batched_distance(A, B))  # [1.0, 3.0, 4.0]

# Double vmap: pairwise distance MATRIX (every a against every b)
# Outer vmap: over rows of A
# Inner vmap: over rows of B
distance_matrix = jax.vmap(jax.vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))

D = distance_matrix(A, B)
print(D.shape)  # (3, 3)
print(D)
# [[1.    0.    0.   ]   distances from A[0] to each B
#  [3.    0.    0.   ]   distances from A[1] to each B
#  [4.    0.    0.   ]]  distances from A[2] to each B
import jax
import jax.numpy as jnp

def process_single(params, x):
    """Complex per-example computation."""
    h = jnp.tanh(x @ params['w1'] + params['b1'])
    return jnp.sum(h ** 2)

# Stack all transformations: vectorize, then compile
process_batch = jax.jit(jax.vmap(process_single, in_axes=(None, 0)))

# Create params as a dict (pytree)
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = {
    'w1': jax.random.normal(k1, (4, 8)),
    'b1': jnp.zeros(8),
}

X = jax.random.normal(k2, (128, 4))  # batch of 128

result = process_batch(params, X)
print(result.shape)  # (128,) — one scalar per example
import jax
import jax.numpy as jnp

def augment_single(key, image):
    """Random augmentation for a single image."""
    k1, k2 = jax.random.split(key)
    # Random brightness
    brightness = jax.random.uniform(k1, minval=0.8, maxval=1.2)
    # Random noise
    noise = 0.05 * jax.random.normal(k2, image.shape)
    return jnp.clip(image * brightness + noise, 0.0, 1.0)

# Vectorize over a batch of (key, image) pairs
augment_batch = jax.jit(jax.vmap(augment_single))

# Each image gets its own random key for independent augmentation
master_key = jax.random.PRNGKey(42)
batch_size = 32
keys = jax.random.split(master_key, batch_size)
images = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, 28, 28))

augmented = augment_batch(keys, images)
print(augmented.shape)  # (32, 28, 28)

External links

Exercise

f(x, y) = x * y 의 2D outer product (1D, 1D → 2D matrix) 을 nested vmap 으로 — jnp.outer 안 쓰고. size 1024×1024 의 jnp.outer 와 비교. nested vmap 의 한계 도달 시점 — jnp.einsum 으로 옮길 때.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.