data-parallel 학습의 표준 패턴 — pmap + grad + collective. 한 번 익히면 — 어느 model 이든 같은 모양으로 확장.
import jax
import jax.numpy as jnp
from functools import partial
# ============ 1. params 를 모든 device 에 복제 ============
def replicate(params, n_devices):
return jax.tree.map(
lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape),
params,
)
# ============ 2. data 를 device 별로 분할 ============
def shard(data, n_devices):
'''(B, ...) → (n_devices, B//n_devices, ...)'''
return data.reshape(n_devices, -1, *data.shape[1:])
# ============ 3. data-parallel train step ============
@partial(jax.pmap, axis_name="data")
def train_step(params, x, y):
def loss_fn(p):
pred = x @ p["w"] + p["b"]
return jnp.mean((pred - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
# 핵심: 모든 device gradient 평균
grads = jax.lax.pmean(grads, axis_name="data")
loss = jax.lax.pmean(loss, axis_name="data")
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
return new_params, loss
# ============ 4. 학습 loop ============
n_devices = jax.device_count()
print(f"{n_devices} devices")
# 초기 params (단일 device 에서 만든 후 replicate)
params = {"w": jnp.zeros(10), "b": jnp.zeros(())}
params = replicate(params, n_devices)
# 한 batch
batch_x = jnp.zeros((128, 10)) # B=128
batch_y = jnp.zeros(128)
batch_x_sharded = shard(batch_x, n_devices) # (4, 32, 10)
batch_y_sharded = shard(batch_y, n_devices) # (4, 32)
for step in range(100):
params, loss = train_step(params, batch_x_sharded, batch_y_sharded)
# loss 는 pmap 결과라 (n_devices,) shape — 모든 device 가 같은 값
if step % 10 == 0:
print(f"step {step}: loss = {loss[0]:.4f}")
핵심 포인트:
- params replicate: 모든 device 가 동일한 params 를 가짐.
(n_devices, ...)모양. - data shard: input 의 첫 axis 가 device axis. 각 device 가 batch 의 1/n_devices.
- local forward + grad: 각 device 는 자기 batch 로 gradient 계산.
- pmean gradient: 모든 device 의 gradient 평균 → 모든 device 가 동일한 update.
- params 동기화 유지: 모든 device 의 params 가 항상 같음 (수학적으로).
checkpoint 저장 / 복원
pmap 후 params 는 (n_devices, ...) 모양 — 그러나 모든 device 가 같으니까 0 번째만 저장:
def get_first_device_params(params):
'''device axis 제거'''
return jax.tree.map(lambda x: x[0], params)
# 저장
single_params = get_first_device_params(params)
# ... pickle 또는 orbax 로 저장 ...
# 복원 후 다시 replicate
loaded_params = ...
params = replicate(loaded_params, n_devices)
📐 data-parallel scaling rule
n_devices 늘릴 때 effective batch size 도 n_devices 배 늘어남. 같은 학습 dynamics 유지하려면 learning rate 도 ~ n_devices 배 (linear scaling rule). 너무 큰 batch 면 warmup + cosine schedule 로 안정화. Track 11 에서 다룸.
이 패턴이 JAX data-parallel 의 정석. Track 7-4 의 sharding API 는 — 같은 일을 더 깨끗하게, 그리고 model parallel 까지 자연스럽게 확장 가능하게 함.