C.W.K.
Stream
Lesson 03 of 05 · published

pmap + grad: distributed training pattern

~8 min · pmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

data-parallel 학습의 표준 패턴 — pmap + grad + collective. 한 번 익히면 — 어느 model 이든 같은 모양으로 확장.

import jax
import jax.numpy as jnp
from functools import partial

# ============ 1. params 를 모든 device 에 복제 ============
def replicate(params, n_devices):
    return jax.tree.map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape),
        params,
    )

# ============ 2. data 를 device 별로 분할 ============
def shard(data, n_devices):
    '''(B, ...) → (n_devices, B//n_devices, ...)'''
    return data.reshape(n_devices, -1, *data.shape[1:])

# ============ 3. data-parallel train step ============
@partial(jax.pmap, axis_name="data")
def train_step(params, x, y):
    def loss_fn(p):
        pred = x @ p["w"] + p["b"]
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # 핵심: 모든 device gradient 평균
    grads = jax.lax.pmean(grads, axis_name="data")
    loss = jax.lax.pmean(loss, axis_name="data")

    new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss

# ============ 4. 학습 loop ============
n_devices = jax.device_count()
print(f"{n_devices} devices")

# 초기 params (단일 device 에서 만든 후 replicate)
params = {"w": jnp.zeros(10), "b": jnp.zeros(())}
params = replicate(params, n_devices)

# 한 batch
batch_x = jnp.zeros((128, 10))   # B=128
batch_y = jnp.zeros(128)
batch_x_sharded = shard(batch_x, n_devices)   # (4, 32, 10)
batch_y_sharded = shard(batch_y, n_devices)   # (4, 32)

for step in range(100):
    params, loss = train_step(params, batch_x_sharded, batch_y_sharded)
    # loss 는 pmap 결과라 (n_devices,) shape — 모든 device 가 같은 값
    if step % 10 == 0:
        print(f"step {step}: loss = {loss[0]:.4f}")

핵심 포인트:

  1. params replicate: 모든 device 가 동일한 params 를 가짐. (n_devices, ...) 모양.
  2. data shard: input 의 첫 axis 가 device axis. 각 device 가 batch 의 1/n_devices.
  3. local forward + grad: 각 device 는 자기 batch 로 gradient 계산.
  4. pmean gradient: 모든 device 의 gradient 평균 → 모든 device 가 동일한 update.
  5. params 동기화 유지: 모든 device 의 params 가 항상 같음 (수학적으로).

checkpoint 저장 / 복원

pmap 후 params 는 (n_devices, ...) 모양 — 그러나 모든 device 가 같으니까 0 번째만 저장:

def get_first_device_params(params):
    '''device axis 제거'''
    return jax.tree.map(lambda x: x[0], params)

# 저장
single_params = get_first_device_params(params)
# ... pickle 또는 orbax 로 저장 ...

# 복원 후 다시 replicate
loaded_params = ...
params = replicate(loaded_params, n_devices)

📐 data-parallel scaling rule

n_devices 늘릴 때 effective batch size 도 n_devices 배 늘어남. 같은 학습 dynamics 유지하려면 learning rate 도 ~ n_devices 배 (linear scaling rule). 너무 큰 batch 면 warmup + cosine schedule 로 안정화. Track 11 에서 다룸.

이 패턴이 JAX data-parallel 의 정석. Track 7-4 의 sharding API 는 — 같은 일을 더 깨끗하게, 그리고 model parallel 까지 자연스럽게 확장 가능하게 함.

Code

import jax
import jax.numpy as jnp

# ============================================
# Model and loss (same as single-device)
# ============================================
def predict(params, x):
    w, b = params
    return x @ w + b

def loss_fn(params, x, y):
    preds = predict(params, x)
    return jnp.mean((preds - y) ** 2)

# ============================================
# Distributed training step
# ============================================
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Synchronize gradients across devices
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    # Apply update
    w, b = params
    gw, gb = grads
    lr = 0.01
    new_params = (w - lr * gw, b - lr * gb)
    return new_params, loss

# Parallelize
p_train_step = jax.pmap(train_step, axis_name='batch')

# ============================================
# Setup: replicate params, shard data
# ============================================
n_devices = jax.device_count()

# Initialize params
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (
    jax.random.normal(k1, (4, 1)) * 0.1,
    jnp.zeros((1,))
)

# REPLICATE params: copy to all devices
# Shape goes from (4, 1) to (n_devices, 4, 1)
replicated_params = jax.tree.map(
    lambda x: jnp.stack([x] * n_devices),
    params
)

# SHARD data: split across devices
total_batch = 256
per_device = total_batch // n_devices
X = jax.random.normal(jax.random.PRNGKey(1), (total_batch, 4))
y = jax.random.normal(jax.random.PRNGKey(2), (total_batch, 1))

X_sharded = X.reshape(n_devices, per_device, 4)
y_sharded = y.reshape(n_devices, per_device, 1)

# ============================================
# Training loop
# ============================================
for step in range(100):
    replicated_params, losses = p_train_step(
        replicated_params, X_sharded, y_sharded
    )
    if step % 20 == 0:
        # Loss is the same on all devices; take the first
        print(f"Step {step}: Loss = {losses[0]:.6f}")

External links

Exercise

5-5 의 linear regression trainer 를 data-parallel 로: simulated 4 device 에 batch shard, train step pmap, gradient psum. 수렴이 single-device 와 일치 검증. 큰 학습으로의 출입구.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.