jax.pmap = parallel map. 여러 device (GPU, TPU) 에서 같은 프로그램을 다른 데이터로 동시에 돌려.
약자: SPMD = Single Program, Multiple Data. 모든 device 가 같은 코드를 실행하지만, 입력 데이터는 device 마다 다른 slice.
import jax
import jax.numpy as jnp
# 사용 가능한 device 확인
print(jax.devices())
print(f"device 개수: {jax.device_count()}")
# 단일 device 환경에서도 시뮬레이션 가능
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
# (다음 import 부터 4 device 인 척)
기본 사용:
def f(x):
return x ** 2 + jnp.sin(x)
# pmap — 첫 번째 axis 가 device axis (vmap 과 비슷)
parallel_f = jax.pmap(f)
# 입력 첫 axis 의 길이가 device 개수와 일치해야 함
x = jnp.arange(8).reshape(4, 2) # 4 devices, 각 device 가 (2,) 받음
result = parallel_f(x) # (4, 2)
각 device 가 입력의 slice 를 받아 같은 함수 실행. 결과의 첫 axis 도 device axis.
학습 step 의 SPMD 화:
def train_step(params, batch_x, batch_y):
'''단일 device 의 train step'''
def loss_fn(p):
return jnp.mean((batch_x @ p - batch_y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
new_params = params - 0.01 * grads
return new_params, loss
# pmap — 모든 device 가 같은 step 실행
parallel_step = jax.pmap(train_step, in_axes=(None, 0, 0))
# params 는 모든 device 에 broadcast, batch 는 device 별로 나눔
# batch_x: (n_devices, B/n_devices, D)
# batch_y: (n_devices, B/n_devices)
new_params, losses = parallel_step(params, batch_x_sharded, batch_y_sharded)
중요한 한 가지 — 이대로면 각 device 의 params 가 따로 update 되어 발산. collective op 으로 gradient 를 모든 device 간 average 해야 함 — 다음 lesson.
vmap vs pmap 차이:
- vmap: 단일 device, batch axis 자동화. 메모리 / op flow 는 한 device 안에서.
- pmap: 여러 device, 각 device 가 독립적인 메모리. communication 은 명시적 collective.
vmap: 하나의 array (B, D) → 하나의 device 가 배치 처리
pmap: 여러 array, 각각 (B', D) → N 개 device 가 각자 처리, 필요시 collective
🌐 pmap = MPI 의 ML 버전
SPMD 는 HPC 에서 오래된 패턴. MPI 프로그래밍 — 모든 노드가 같은 binary 실행, 차이는 input 데이터, 명시적 message passing 으로 sync. pmap 은 그 모델을 ML 에 가져온 것. data-parallel 학습이 SPMD 의 정석 application.
현대 trends — JAX 팀은 jax.sharding + Mesh 로 pmap 을 점진적으로 대체하는 중. pmap 은 단순 data-parallel 에 강하지만, model parallel 같은 복잡 패턴에선 sharding API 가 더 깨끗. Track 7-4 에서 다룸.