Lesson 05 of 05 · published

Hands-on: A Data-Parallel Training Pipeline

~12 min · pmap, jax, tutorial


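This lesson combines everything covered so far into a complete data-parallel training pipeline: a mini trainer that runs on 4 simulated devices (XLA host-platform CPU devices standing in for 4 GPUs).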

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import random
import numpy as np

# ============ 0. build the mesh ============
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))   # 1D data-parallel mesh
n_devices = len(devices)
print(f"{n_devices} devices, mesh: {mesh}")

# ============ 1. data ============
key = random.PRNGKey(42)
key, x_key, y_key = random.split(key, 3)
N, D = 4096, 32
X = random.normal(x_key, (N, D))
true_w = random.normal(random.PRNGKey(1), (D,))
y = X @ true_w + 0.1 * random.normal(y_key, (N,))

# ============ 2. model + loss ============
def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# ============ 3. train step ============
@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# ============ 4. shard + replicate ============
def shard_data(x):
    '''Split the leading (batch) axis across the "data" mesh axis.'''
    return jax.device_put(x, NamedSharding(mesh, P("data")))

def replicate_params(params):
    return jax.device_put(params, NamedSharding(mesh, P()))

# Initialize parameters
key, init_key = random.split(key)
params = {
    "w": random.normal(init_key, (D,)) * 0.01,
    "b": jnp.zeros(()),
}
params = replicate_params(params)

# Shard the data
batch_size = 256   # global batch size (split across all devices)
X_sharded = shard_data(X[:batch_size])
y_sharded = shard_data(y[:batch_size])

# ============ 5. training ============
for step in range(500):
    params, loss = train_step(params, X_sharded, y_sharded, 0.01)
    if step % 50 == 0:
        # loss is the mean over the global batch; jit inserts the cross-device reduction
        print(f"step {step:3d}: loss = {float(loss):.4f}")

# ============ 6. save a checkpoint ============
# params are replicated, so every device holds the same values: just fetch them to the host
host_params = jax.device_get(params)
# host_params["w"].shape == (D,), host_params["b"].shape == ()

print(f"\nfinal w[:5]: {host_params['w'][:5]}")
print(f"true  w[:5]: {true_w[:5]}")
print(f"\nfinal b: {float(host_params['b']):.4f}")
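To see what "shard the data, replicate the params" means on the devices, the placement of each array can be inspected directly. The following is a minimal sketch under the same simulated-device setup; the array sizes are illustrative and not taken from the trainer above.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_devices = mesh.devices.size

# A batch whose leading axis divides evenly across the devices (illustrative sizes).
x = jnp.ones((64 * n_devices, 32))
w = jnp.zeros((32,))

x_sharded = jax.device_put(x, NamedSharding(mesh, P("data")))
w_replicated = jax.device_put(w, NamedSharding(mesh, P()))

# Each device holds one slice of the batch but a full copy of the weights.
x_shard_shapes = [s.data.shape for s in x_sharded.addressable_shards]
w_shard_shapes = [s.data.shape for s in w_replicated.addressable_shards]
print("x shards:", x_shard_shapes)
print("w shards:", w_shard_shapes)
print("w fully replicated:", w_replicated.sharding.is_fully_replicated)
```

Checking shard shapes like this is a cheap way to confirm that the batch size divides evenly across the mesh before starting a long run.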

This pattern is the starting point for production-scale training code. The difference is the size of the model:

  • Here: linear regression, 33 params
  • In practice: Llama-3 70B, 70 billion params; same pattern, a larger mesh, FSDP sharding
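The parameter counts above can be checked with a few lines of arithmetic. This sketch counts the leaves of the linear-regression pytree and estimates the weight memory of a 70-billion-parameter model in bfloat16 (2 bytes per parameter); `count_params` is a helper introduced here for illustration.

```python
import jax
import jax.numpy as jnp

D = 32
params = {"w": jnp.zeros((D,)), "b": jnp.zeros(())}

def count_params(tree):
    """Total number of scalar entries across all leaves of a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

n_params = count_params(params)
print(n_params)  # 32 weights + 1 bias

# Weights alone for 70e9 params at 2 bytes each (bfloat16), in GB.
weights_gb = 70e9 * 2 / 1e9
print(weights_gb)
```

At that size the weights no longer fit on a single accelerator, which is why the large-model case shards the parameters themselves instead of replicating them.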

Next step: adding model parallelism

# 2D mesh: data 2 x model 2 (fits the 4 simulated devices; a 4 x 2 layout needs 8)
devices = np.array(jax.devices()).reshape(2, 2)
mesh_2d = Mesh(devices, axis_names=("data", "model"))

# Shard a large layer along the model axis
D_hidden = 64   # illustrative hidden size
W_sharded = jax.device_put(
    jnp.zeros((D, D_hidden)),
    NamedSharding(mesh_2d, P(None, "model")),
)
# The data axis splits the batch; the model axis splits the weight's columns

FSDP (Fully Sharded Data Parallel), pipeline parallelism, and similar schemes are variations on the same sharding API. Once you learn it, every distributed pattern takes the same shape.
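As a rough illustration of the FSDP idea, the sketch below shards a weight matrix along the same "data" axis that splits the batch, so each device stores only a slice of the parameters, and checks that the gradient matches a single-device run. It covers only the sharding annotation; a real FSDP setup would also shard the optimizer state. All sizes are illustrative assumptions.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = mesh.devices.size

D_in, D_out = 8 * n, 16          # illustrative sizes, divisible by the mesh
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4 * n, D_in))
w = jax.random.normal(key_w, (D_in, D_out))

def loss_fn(w, x):
    return jnp.mean((x @ w) ** 2)

grad_fn = jax.jit(jax.grad(loss_fn))

# Reference: plain arrays, no sharding annotations.
g_ref = grad_fn(w, x)

# FSDP-style placement: the batch AND the weight rows are split over "data".
x_s = jax.device_put(x, NamedSharding(mesh, P("data")))
w_s = jax.device_put(w, NamedSharding(mesh, P("data", None)))
g_fsdp = grad_fn(w_s, x_s)

per_device_rows = w_s.addressable_shards[0].data.shape[0]
max_diff = float(np.max(np.abs(np.asarray(g_fsdp) - np.asarray(g_ref))))
print("rows of w per device:", per_device_rows, "of", D_in)
print("max gradient difference:", max_diff)
```

The function being differentiated is unchanged between the two runs; only the placement of `w` and `x` differs.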

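🌐 Distributed training in one line

"Decide how to shard the params → jit takes care of the collectives inside each step → the training loop looks just like single-device code." That is the promise of jax.sharding, and it holds almost all of the time. Distributed training no longer needs a separate mental model.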

A real production checklist:

  • checkpointing with orbax (Track 11-5)
  • a better optimizer with optax (Track 11)
  • model abstractions with flax / equinox (Track 10)
  • ✅ mixed precision (bfloat16 forward, float32 optimizer state)
  • ✅ gradient accumulation (to raise the effective batch size)
  • ✅ logging with wandb
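Gradient accumulation from the checklist can be sketched in plain JAX. Assuming equal-sized micro-batches and a mean loss, averaging the per-micro-batch gradients gives the same result as one large batch; `accumulated_grads` is a hypothetical helper written for this illustration, not a library function.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def accumulated_grads(params, x_micro, y_micro):
    """Average gradients over the leading micro-batch axis using lax.scan."""
    def body(acc, batch):
        x, y = batch
        g = jax.grad(loss_fn)(params, x, y)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (x_micro, y_micro))
    n_micro = x_micro.shape[0]
    return jax.tree_util.tree_map(lambda g: g / n_micro, total)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 32))
y = jax.random.normal(key_y, (256,))
params = {"w": jnp.zeros((32,)), "b": jnp.zeros(())}

# 4 micro-batches of 64 stand in for one batch of 256.
g_accum = accumulated_grads(params, x.reshape(4, 64, 32), y.reshape(4, 64))
g_full = jax.grad(loss_fn)(params, x, y)
print(float(jnp.max(jnp.abs(g_accum["w"] - g_full["w"]))))
```

Only one micro-batch of activations is live at a time inside the scan, which is what lets the effective batch size grow beyond what fits in memory.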

The code in this lesson is the MVP. Adding the items above makes it a production trainer.
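The mixed-precision item from the checklist can be illustrated without extra libraries. In this sketch the master parameters stay in float32, the forward pass runs in bfloat16, and the gradients come back in float32 because differentiation is taken with respect to the float32 copy. The function names, learning rate, and data are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def mixed_precision_loss(master_params, x, y):
    # Cast the float32 master weights and inputs to bfloat16 for the forward pass.
    p16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), master_params)
    pred = x.astype(jnp.bfloat16) @ p16["w"] + p16["b"]
    # Compute the loss itself in float32.
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)

@jax.jit
def mixed_precision_step(master_params, x, y, lr):
    loss, grads = jax.value_and_grad(mixed_precision_loss)(master_params, x, y)
    # grads are float32 because they are taken w.r.t. the float32 master params.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, master_params, grads)
    return new_params, loss

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 32))
true_w = jax.random.normal(key_w, (32,))
y = x @ true_w

params = {"w": jnp.zeros((32,)), "b": jnp.zeros(())}
losses = []
for _ in range(20):
    params, loss = mixed_precision_step(params, x, y, 0.05)
    losses.append(float(loss))

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Keeping the update in float32 avoids losing small gradient steps to bfloat16 rounding, which is the reason the checklist pairs a bfloat16 forward pass with float32 optimizer state.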

Code

import jax
import jax.numpy as jnp

# Model: simple 2-layer network
def init_params(key, in_dim, hidden_dim, out_dim):
    k1, k2 = jax.random.split(key)
    return {
        'w1': jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1,
        'b2': jnp.zeros(out_dim),
    }

def forward(params, x):
    h = jnp.tanh(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']

def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

# Data-parallel training step
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Synchronize across devices
    grads = jax.lax.pmean(grads, axis_name='i')
    loss = jax.lax.pmean(loss, axis_name='i')
    # SGD update
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

p_train_step = jax.pmap(train_step, axis_name='i')

# Initialize
n_devices = jax.device_count()
params = init_params(jax.random.PRNGKey(0), 10, 32, 1)
replicated_params = jax.tree.map(lambda x: jnp.stack([x] * n_devices), params)

# Generate data (sharded across devices)
key = jax.random.PRNGKey(1)
X = jax.random.normal(key, (n_devices * 64, 10))
y = jax.random.normal(jax.random.PRNGKey(2), (n_devices * 64, 1))
X = X.reshape(n_devices, 64, 10)
y = y.reshape(n_devices, 64, 1)

# Train
for step in range(50):
    replicated_params, losses = p_train_step(replicated_params, X, y)
    if step % 10 == 0:
        print(f"Step {step}: Loss = {losses[0]:.6f}")
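With pmap, every parameter carries a leading device axis, so saving or evaluating the model means dropping that axis again. The sketch below assumes the same stacking convention as the listing above; because pmean hands every device the same gradient, the replicas stay identical and the first one can be taken. The toy params and gradients are illustrative.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()
params = {"w": jnp.arange(3.0), "b": jnp.zeros(())}
replicated = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), params)

def step(p, g):
    # Average the per-device gradients so all replicas apply the same update.
    g = jax.lax.pmean(g, axis_name="i")
    return jax.tree_util.tree_map(lambda a, b: a - 0.1 * b, p, g)

# Deliberately different gradients on each device.
grads = {
    "w": jnp.arange(n_devices * 3.0).reshape(n_devices, 3),
    "b": jnp.arange(float(n_devices)),
}
replicated = jax.pmap(step, axis_name="i")(replicated, grads)

# Drop the device axis: replica 0 is representative of all of them.
unreplicated = jax.tree_util.tree_map(lambda x: x[0], replicated)
print({k: v.shape for k, v in unreplicated.items()})
```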
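
# --- The same training, written with jax.sharding + jit ---
# Reuses init_params and loss_fn from the pmap example above.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Create mesh and shardings
mesh = jax.make_mesh((jax.device_count(),), ('dp',))
data_sharding = NamedSharding(mesh, P('dp'))
replicated_sharding = NamedSharding(mesh, P())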

# Initialize params (replicated across all devices)
params = init_params(jax.random.PRNGKey(0), 10, 32, 1)
params = jax.device_put(params, replicated_sharding)

# Shard data across the data-parallel axis
X = jax.random.normal(jax.random.PRNGKey(1), (256, 10))
y = jax.random.normal(jax.random.PRNGKey(2), (256, 1))
X = jax.device_put(X, data_sharding)
y = jax.device_put(y, data_sharding)

# Training step — just use jit! The sharding handles distribution.
@jax.jit
def train_step_sharded(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

# Train — XLA handles the all-reduce automatically
for step in range(50):
    params, loss = train_step_sharded(params, X, y)
    if step % 10 == 0:
        print(f"Step {step}: Loss = {loss:.6f}")
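One way to convince yourself that the sharded jit version computes the same thing as single-device code is to run one step both ways and compare. This is a minimal sketch with an illustrative linear model; the final lines peek at the compiled program for an all-reduce, and what that text contains is an assumption that may vary by backend and JAX version.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = mesh.devices.size

def loss_fn(params, x, y):
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads), loss

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64 * n, 32))
y = jax.random.normal(key_y, (64 * n,))
params = {"w": jnp.zeros((32,)), "b": jnp.zeros(())}

# Single-device run: plain arrays, no sharding annotations.
p_single, loss_single = train_step(params, x, y)

# Data-parallel run: identical function, only the input placement changes.
x_s = jax.device_put(x, NamedSharding(mesh, P("data")))
y_s = jax.device_put(y, NamedSharding(mesh, P("data")))
params_r = jax.device_put(params, NamedSharding(mesh, P()))
p_multi, loss_multi = train_step(params_r, x_s, y_s)

w_diff = float(np.max(np.abs(np.asarray(p_multi["w"]) - np.asarray(p_single["w"]))))
loss_diff = abs(float(loss_multi) - float(loss_single))
print("max |w| difference:", w_diff, " loss difference:", loss_diff)

# Optional: look for the collective in the compiled program text.
hlo_text = train_step.lower(params_r, x_s, y_s).compile().as_text() or ""
print("all-reduce in compiled HLO:", "all-reduce" in hlo_text)
```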

Exercise

Build a complete data-parallel example: model + loss + optimizer + train loop, sharded across all devices. Train for 200 steps on synthetic data. Save the final checkpoint with Orbax. The deliverable is the entire pipeline.
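While wiring up the rest of the pipeline, a placeholder checkpoint step can help test the save and restore path. The sketch below uses NumPy's `.npz` format as a stand-in; it is not the Orbax API, and the exercise still asks for Orbax in the final version. The file name and toy params are assumptions.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

params = {"w": jnp.arange(4.0), "b": jnp.ones(())}

# Replicated params are identical on every device, so one host copy is enough.
host_params = jax.device_get(params)

path = os.path.join(tempfile.mkdtemp(), "final_checkpoint.npz")
np.savez(path, **host_params)

# Restore into a fresh pytree of JAX arrays.
with np.load(path) as data:
    restored = {name: jnp.asarray(data[name]) for name in data.files}

print(sorted(restored.keys()))
```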
