C.W.K.
Stream
Lesson 04 of 05 · published

현대 대안: jax.sharding 과 Mesh

~9 min · pmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

2024+ JAX 팀이 권장하는 새로운 분산 API — jax.sharding + Mesh. pmap 보다 깨끗하고, model parallel 까지 자연스럽게 확장.

핵심 아이디어 — array 자체에 "어떻게 분산되어 있는지" 라는 sharding 정보를 부여. 그러면 jit 이 sharding 을 알아서 해석하고 collective 를 자동 삽입.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np

# ============ Mesh 만들기 ============
# 8 devices 를 (data=4, model=2) 2D mesh 로
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

# ============ Sharded array 만들기 ============
sharding = NamedSharding(mesh, P("data", "model"))

# 큰 array 를 mesh 에 분산
x = jnp.arange(64*32).reshape(64, 32)
x_sharded = jax.device_put(x, sharding)

# 시각화
jax.debug.visualize_array_sharding(x_sharded)
# +-----+-----+
# | d0  | d4  |   ← data=0, model=0,1
# |-----+-----|
# | d1  | d5  |
# |-----+-----|
# | d2  | d6  |
# |-----+-----|
# | d3  | d7  |
# +-----+-----+

jit 이 자동으로 분산 처리

@jax.jit
def matmul(a, b):
    return a @ b

# input 을 sharded 하면 — output 도 sharded, collective 자동 삽입
a = jax.device_put(jnp.ones((1024, 512)), NamedSharding(mesh, P("data", None)))
b = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P(None, "model")))

c = matmul(a, b)  # output 도 (data, model) 분산

같은 함수, sharding 만 바뀌면 — data parallel, tensor parallel, FSDP 다 가능.

학습 loop — sharding 으로

def train_step(params, x, y, lr):
    def loss_fn(p):
        return jnp.mean((x @ p - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

train_step = jax.jit(train_step)

# data parallel: params 는 replicated, batch 는 sharded
data_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P())   # 모든 device 가 같은 값

params_repl = jax.device_put(params, replicated)
x_sharded = jax.device_put(batch_x, data_sharding)
y_sharded = jax.device_put(batch_y, NamedSharding(mesh, P("data")))

# jit 이 자동으로 collective 삽입
new_params, loss = train_step(params_repl, x_sharded, y_sharded, 0.01)

pmap 대비 장점:

  • jit 호환: 평범한 jit 으로 분산 가능. pmap 의 별도 wrapping 없음.
  • 다양한 sharding: data parallel, model parallel, FSDP, pipeline 등 같은 API 로.
  • collective 자동: pmean / psum / all_gather 등을 명시 안 해도 됨 — XLA 가 자동.
  • composable: vmap / grad 와 자유롭게 합성.

🆕 새 코드는 sharding 으로

2024+ JAX 코드 — pmap 보다 sharding API. pmap 은 deprecated 는 아니지만 maintenance mode. DeepMind, Google 의 새 모델들 (Gemini 등) 이 다 sharding 기반. 이 quest 의 다음 lesson 도 sharding 패턴으로.

전환 가이드 — pmap 코드를 sharding 으로 옮길 때:

  1. jax.pmap(f, axis_name="data")jax.jit(f)
  2. params replicate → NamedSharding(mesh, P())
  3. batch shard → NamedSharding(mesh, P("data"))
  4. jax.lax.pmean(grads) → 자동 (jit 이 처리)

Code

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
# Since JAX 0.7+, you can also use: jax.P as a shortcut for PartitionSpec

# Step 1: Create a mesh — a logical arrangement of devices
devices = jax.devices()  # All available devices
mesh = jax.make_mesh((len(devices),), ('data',))

# Step 2: Define sharding with PartitionSpec
# P('data') means "shard along the 'data' axis of the mesh"
data_sharding = NamedSharding(mesh, P('data'))
replicated = NamedSharding(mesh, P())  # No sharding = replicated

# Step 3: Place data on devices
X = jax.random.normal(jax.random.PRNGKey(0), (128, 4))
X_sharded = jax.device_put(X, data_sharding)

# Check the sharding
jax.debug.visualize_array_sharding(X_sharded)
import jax
import jax.numpy as jnp

# Create a 2D mesh: (data_parallel, model_parallel)
# For example, 8 devices arranged as 4x2
mesh = jax.make_mesh((4, 2), ('dp', 'mp'))

# Data is sharded along the 'dp' axis
data_sharding = jax.NamedSharding(mesh, jax.P('dp', None))

# Model weights are sharded along the 'mp' axis
weight_sharding = jax.NamedSharding(mesh, jax.P(None, 'mp'))

# With jit + sharding constraints, JAX handles communication automatically
@jax.jit
def forward(params, x):
    return x @ params

# JAX inserts the necessary all-gathers and reduce-scatters
import jax
import jax.numpy as jnp

# Since JAX 0.7, jax.P is an alias for jax.sharding.PartitionSpec
mesh = jax.make_mesh((4, 2), ('dp', 'mp'))

# These are equivalent:
sharding1 = jax.NamedSharding(mesh, jax.P('dp', None))
sharding2 = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('dp', None))

# jax.P is more concise and is the recommended style

External links

Exercise

pmap-기반 data-parallel trainer 를 jax.sharding.NamedSharding + Mesh 로 교체. 같은 workload 가 — output sharding 을 explicit 하게. 'jit 이 sharded array 와 그냥 작동' 이 pmap ceremony 를 대체하는 거 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.