C.W.K.
Stream
Lesson 02 of 05 · published

Collective Operations: device 간 통신

~8 min · pmap, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

pmap 안에서 — 여러 device 가 각자의 결과를 합치거나 공유해야 할 때 — collective ops 사용.

가장 흔한 collective: jax.lax.psum — 모든 device 의 값을 합쳐.

def f(x):
    '''각 device 의 x 합을 모든 device 에 broadcast'''
    return jax.lax.psum(x, axis_name="i")

# pmap 으로 device axis 에 이름 부여
parallel = jax.pmap(f, axis_name="i")

# 4 device, 각각 다른 값
x = jnp.array([1.0, 2.0, 3.0, 4.0])  # device 4개
result = parallel(x)
print(result)  # [10. 10. 10. 10.] — 모든 device 가 같은 합

주요 collective ops:

  • psum(x, axis_name) — sum across devices
  • pmean(x, axis_name) — mean across devices
  • pmax(x, axis_name) — max across devices
  • pmin(x, axis_name) — min across devices
  • all_gather(x, axis_name) — 모든 device 의 x 를 concat
  • all_to_all(x, ...) — 모든 device 가 모든 device 에게 데이터 전송

data-parallel gradient averaging — 표준 use case:

def train_step(params, batch_x, batch_y):
    def loss_fn(p):
        return jnp.mean((batch_x @ p - batch_y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # 모든 device 의 gradient 평균
    grads = jax.lax.pmean(grads, axis_name="data")
    loss = jax.lax.pmean(loss, axis_name="data")

    new_params = params - 0.01 * grads
    return new_params, loss

parallel_step = jax.pmap(
    train_step,
    in_axes=(None, 0, 0),
    axis_name="data",
)

이게 — 4 GPU 에서 batch_size=128 학습할 때, 각 GPU 가 32 example 처리한 뒤 gradient 를 평균. 결과적으로 batch_size=128 의 학습과 동등.

여러 axis

# 2D mesh — data parallel + model parallel
parallel = jax.pmap(
    jax.pmap(f, axis_name="model"),
    axis_name="data",
)

# 두 axis 다 통합한 reduction
def f(x):
    s_data = jax.lax.psum(x, axis_name="data")
    s_model = jax.lax.psum(x, axis_name="model")
    s_all = jax.lax.psum(x, axis_name=("data", "model"))
    return s_all

큰 모델에서는 — data axis 로 batch 를 나누고, model axis 로 parameters 를 나누는 2D 패턴이 흔함.

🔄 collective 의 비용

collective op 은 — 네트워크 비용. 4 GPU 가 psum 하면 — 각 GPU 의 데이터를 다른 GPU 로 전송. 큰 model 에서 gradient (수 GB) 를 매 step psum 하는 게 — 학습 속도의 병목이 되는 일이 많아. NVIDIA NCCL, JAX collective 의 효율 차이가 학습 시간에 큰 영향.

최적화 — gradient accumulation 으로 collective 빈도 감소, mixed precision 으로 데이터 크기 감소, all_reduce 대신 reduce_scatter + all_gather 패턴 (ZeRO).

Code

import jax
import jax.numpy as jnp

# psum: sum values across all devices (all-reduce sum)
@jax.pmap
def sum_across_devices(x):
    # x is this device's local value
    total = jax.lax.psum(x, axis_name='devices')
    return total

# Note: pmap needs to know the axis name for collectives
sum_across_devices = jax.pmap(
    lambda x: jax.lax.psum(x, axis_name='i'),
    axis_name='i'
)

# pmean: average across devices
mean_fn = jax.pmap(
    lambda x: jax.lax.pmean(x, axis_name='i'),
    axis_name='i'
)

# pmax: maximum across devices
max_fn = jax.pmap(
    lambda x: jax.lax.pmax(x, axis_name='i'),
    axis_name='i'
)
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr):
    """Training step that runs on each device."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

    # Average gradients across all devices
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')

    # Update (all devices now have the same gradients → same params)
    new_params = params - lr * grads
    return new_params, loss

# Parallelize with pmap
parallel_train_step = jax.pmap(train_step, axis_name='devices')

External links

Exercise

pmap'd 함수 안에서 per-device loss 계산 후 jax.lax.psum 으로 global sum. 모든 device 의 결과 동일 확인. pmean 으로 반복.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.