지금까지 배운 것 합쳐 — 완전한 data-parallel 학습 파이프라인. 4 GPU 시뮬레이션 환경에서 작동하는 미니 trainer.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import random
import numpy as np
# ============ 0. mesh 만들기 ============
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",)) # 1D data-parallel mesh
n_devices = len(devices)
print(f"{n_devices} devices, mesh: {mesh}")
# ============ 1. data ============
key = random.PRNGKey(42)
key, x_key, y_key = random.split(key, 3)
N, D = 4096, 32
X = random.normal(x_key, (N, D))
true_w = random.normal(random.PRNGKey(1), (D,))
y = X @ true_w + 0.1 * random.normal(y_key, (N,))
# ============ 2. model + loss ============
def predict(params, x):
return x @ params["w"] + params["b"]
def loss_fn(params, x, y):
return jnp.mean((predict(params, x) - y) ** 2)
# ============ 3. train step ============
@jax.jit
def train_step(params, x, y, lr):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
return new_params, loss
# ============ 4. shard + replicate ============
def shard_data(x):
'''첫 axis 를 device 축으로 나눔'''
return jax.device_put(x, NamedSharding(mesh, P("data")))
def replicate_params(params):
return jax.device_put(params, NamedSharding(mesh, P()))
# 초기화
key, init_key = random.split(key)
params = {
"w": random.normal(init_key, (D,)) * 0.01,
"b": jnp.zeros(()),
}
params = replicate_params(params)
# data shard
batch_size = 256 # 전체 batch
X_sharded = shard_data(X[:batch_size])
y_sharded = shard_data(y[:batch_size])
# ============ 5. 학습 ============
for step in range(500):
params, loss = train_step(params, X_sharded, y_sharded, 0.01)
if step % 50 == 0:
# loss 는 자동으로 device 간 평균
print(f"step {step:3d}: loss = {float(loss):.4f}")
# ============ 6. checkpoint 저장 ============
# 모든 device 가 같은 params 라 — 그냥 host 로 가져옴
host_params = jax.device_get(params)
# host_params["w"].shape == (D,), host_params["b"].shape == ()
print(f"\n최종 w[:5]: {host_params['w'][:5]}")
print(f"true w[:5]: {true_w[:5]}")
print(f"\n최종 b: {host_params['b']:.4f}")
이 패턴이 — production-scale 학습 코드의 출발점. 차이는 model 의 크기:
- 지금: linear regression — 33 params
- 실제: Llama-3 70B — 700억 params, 동일 패턴, 더 큰 mesh, FSDP sharding
다음 단계 — model parallel 추가
# 2D mesh: data 4 + model 2
devices = np.array(jax.devices()).reshape(4, 2)
mesh_2d = Mesh(devices, axis_names=("data", "model"))
# 큰 layer 를 model axis 로 sharding
W_sharded = jax.device_put(
jnp.zeros((D, D_hidden)),
NamedSharding(mesh_2d, P(None, "model")),
)
# data axis 는 batch, model axis 는 weight 의 column 분산
FSDP (Fully Sharded Data Parallel), pipeline parallel 등 — 같은 sharding API 의 변형. 한 번 익히면 — 모든 분산 패턴이 같은 모양.
🌐 분산 학습의 한 줄 요약
"params 를 어떻게 sharding 할지 정한다 → 각 step 안의 collective 는 jit 이 알아서 → 학습 loop 는 single-device 코드와 똑같이 보인다." 이게 jax.sharding 의 약속이고, 거의 100% 지켜져. distributed 학습이 더 이상 별도 mental model 이 아닌 시대.
실제 production checklist:
- ✅
orbax로 checkpoint (Track 11-5) - ✅
optax로 더 좋은 optimizer (Track 11) - ✅
flax/equinox로 model 추상화 (Track 10) - ✅ mixed precision (bfloat16 forward, float32 optimizer state)
- ✅ gradient accumulation (effective batch size 키우기)
- ✅ wandb 로 logging
지금 이 코드 = MVP. 위 항목 추가 = production trainer.