C.W.K.
Stream
Lesson 01 of 05 · published

pmap 의 일: 여러 device 에서 SPMD

~8 min · pmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.pmap = parallel map. 여러 device (GPU, TPU) 에서 같은 프로그램을 다른 데이터로 동시에 돌려.

약자: SPMD = Single Program, Multiple Data. 모든 device 가 같은 코드를 실행하지만, 입력 데이터는 device 마다 다른 slice.

import jax
import jax.numpy as jnp

# 사용 가능한 device 확인
print(jax.devices())
print(f"device 개수: {jax.device_count()}")

# 단일 device 환경에서도 시뮬레이션 가능
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
# (다음 import 부터 4 device 인 척)

기본 사용:

def f(x):
    return x ** 2 + jnp.sin(x)

# pmap — 첫 번째 axis 가 device axis (vmap 과 비슷)
parallel_f = jax.pmap(f)

# 입력 첫 axis 의 길이가 device 개수와 일치해야 함
x = jnp.arange(8).reshape(4, 2)  # 4 devices, 각 device 가 (2,) 받음
result = parallel_f(x)            # (4, 2)

각 device 가 입력의 slice 를 받아 같은 함수 실행. 결과의 첫 axis 도 device axis.

학습 step 의 SPMD 화:

def train_step(params, batch_x, batch_y):
    '''단일 device 의 train step'''
    def loss_fn(p):
        return jnp.mean((batch_x @ p - batch_y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = params - 0.01 * grads
    return new_params, loss

# pmap — 모든 device 가 같은 step 실행
parallel_step = jax.pmap(train_step, in_axes=(None, 0, 0))

# params 는 모든 device 에 broadcast, batch 는 device 별로 나눔
# batch_x: (n_devices, B/n_devices, D)
# batch_y: (n_devices, B/n_devices)
new_params, losses = parallel_step(params, batch_x_sharded, batch_y_sharded)

중요한 한 가지 — 이대로면 각 device 의 params 가 따로 update 되어 발산. collective op 으로 gradient 를 모든 device 간 average 해야 함 — 다음 lesson.

vmap vs pmap 차이:

  • vmap: 단일 device, batch axis 자동화. 메모리 / op flow 는 한 device 안에서.
  • pmap: 여러 device, 각 device 가 독립적인 메모리. communication 은 명시적 collective.
vmap:  하나의 array (B, D) → 하나의 device 가 배치 처리
pmap:  여러 array, 각각 (B', D) → N 개 device 가 각자 처리, 필요시 collective

🌐 pmap = MPI 의 ML 버전

SPMD 는 HPC 에서 오래된 패턴. MPI 프로그래밍 — 모든 노드가 같은 binary 실행, 차이는 input 데이터, 명시적 message passing 으로 sync. pmap 은 그 모델을 ML 에 가져온 것. data-parallel 학습이 SPMD 의 정석 application.

현대 trends — JAX 팀은 jax.sharding + Mesh 로 pmap 을 점진적으로 대체하는 중. pmap 은 단순 data-parallel 에 강하지만, model parallel 같은 복잡 패턴에선 sharding API 가 더 깨끗. Track 7-4 에서 다룸.

Code

import jax
import jax.numpy as jnp

# Check available devices
print(jax.devices())          # e.g., [CudaDevice(id=0), CudaDevice(id=1)]
print(jax.device_count())     # e.g., 2

# A simple function
def square(x):
    return x ** 2

# pmap: run on all devices in parallel
parallel_square = jax.pmap(square)

# Input must have leading dimension = number of devices
n_devices = jax.device_count()
x = jnp.arange(n_devices * 4).reshape(n_devices, 4)
print(f"Input shape: {x.shape}")   # (n_devices, 4)

# Each device gets one slice along axis 0
result = parallel_square(x)
print(f"Output shape: {result.shape}")  # (n_devices, 4)
print(result)
import jax
import jax.numpy as jnp

# Preparing data for pmap: reshape to (n_devices, per_device_batch, ...)
n_devices = jax.device_count()
total_batch = 128  # Total examples
per_device = total_batch // n_devices

# Original data: (128, 784)
data = jax.random.normal(jax.random.PRNGKey(0), (total_batch, 784))

# Reshape for pmap: (n_devices, per_device_batch, 784)
sharded_data = data.reshape(n_devices, per_device, 784)
print(f"Sharded shape: {sharded_data.shape}")

External links

Exercise

jax.devices() 실행 — 여러 개면 pmap 사용. CPU 만이면 XLA_FLAGS='--xla_force_host_platform_device_count=4' 설정 후 launch. 4 'device' 에 square 함수 pmap, output shape 에 device axis 포함 검증.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.