2024+ JAX 팀이 권장하는 새로운 분산 API — jax.sharding + Mesh. pmap 보다 깨끗하고, model parallel 까지 자연스럽게 확장.
핵심 아이디어 — array 자체에 "어떻게 분산되어 있는지" 라는 sharding 정보를 부여. 그러면 jit 이 sharding 을 알아서 해석하고 collective 를 자동 삽입.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
# ============ Mesh 만들기 ============
# 8 devices 를 (data=4, model=2) 2D mesh 로
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))
# ============ Sharded array 만들기 ============
sharding = NamedSharding(mesh, P("data", "model"))
# 큰 array 를 mesh 에 분산
x = jnp.arange(64*32).reshape(64, 32)
x_sharded = jax.device_put(x, sharding)
# 시각화
jax.debug.visualize_array_sharding(x_sharded)
# +-----+-----+
# | d0 | d4 | ← data=0, model=0,1
# |-----+-----|
# | d1 | d5 |
# |-----+-----|
# | d2 | d6 |
# |-----+-----|
# | d3 | d7 |
# +-----+-----+
jit 이 자동으로 분산 처리
@jax.jit
def matmul(a, b):
return a @ b
# input 을 sharded 하면 — output 도 sharded, collective 자동 삽입
a = jax.device_put(jnp.ones((1024, 512)), NamedSharding(mesh, P("data", None)))
b = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P(None, "model")))
c = matmul(a, b) # output 도 (data, model) 분산
같은 함수, sharding 만 바뀌면 — data parallel, tensor parallel, FSDP 다 가능.
학습 loop — sharding 으로
def train_step(params, x, y, lr):
def loss_fn(p):
return jnp.mean((x @ p - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
return new_params, loss
train_step = jax.jit(train_step)
# data parallel: params 는 replicated, batch 는 sharded
data_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P()) # 모든 device 가 같은 값
params_repl = jax.device_put(params, replicated)
x_sharded = jax.device_put(batch_x, data_sharding)
y_sharded = jax.device_put(batch_y, NamedSharding(mesh, P("data")))
# jit 이 자동으로 collective 삽입
new_params, loss = train_step(params_repl, x_sharded, y_sharded, 0.01)
pmap 대비 장점:
- jit 호환: 평범한 jit 으로 분산 가능. pmap 의 별도 wrapping 없음.
- 다양한 sharding: data parallel, model parallel, FSDP, pipeline 등 같은 API 로.
- collective 자동: pmean / psum / all_gather 등을 명시 안 해도 됨 — XLA 가 자동.
- composable: vmap / grad 와 자유롭게 합성.
🆕 새 코드는 sharding 으로
2024+ JAX 코드 — pmap 보다 sharding API. pmap 은 deprecated 는 아니지만 maintenance mode. DeepMind, Google 의 새 모델들 (Gemini 등) 이 다 sharding 기반. 이 quest 의 다음 lesson 도 sharding 패턴으로.
전환 가이드 — pmap 코드를 sharding 으로 옮길 때:
jax.pmap(f, axis_name="data")→jax.jit(f)- params replicate →
NamedSharding(mesh, P()) - batch shard →
NamedSharding(mesh, P("data")) jax.lax.pmean(grads)→ 자동 (jit 이 처리)