pmap 안에서 — 여러 device 가 각자의 결과를 합치거나 공유해야 할 때 — collective ops 사용.
가장 흔한 collective: jax.lax.psum — 모든 device 의 값을 합쳐.
def f(x):
'''각 device 의 x 합을 모든 device 에 broadcast'''
return jax.lax.psum(x, axis_name="i")
# pmap 으로 device axis 에 이름 부여
parallel = jax.pmap(f, axis_name="i")
# 4 device, 각각 다른 값
x = jnp.array([1.0, 2.0, 3.0, 4.0]) # device 4개
result = parallel(x)
print(result) # [10. 10. 10. 10.] — 모든 device 가 같은 합
주요 collective ops:
psum(x, axis_name)— sum across devicespmean(x, axis_name)— mean across devicespmax(x, axis_name)— max across devicespmin(x, axis_name)— min across devicesall_gather(x, axis_name)— 모든 device 의 x 를 concatall_to_all(x, ...)— 모든 device 가 모든 device 에게 데이터 전송
data-parallel gradient averaging — 표준 use case:
def train_step(params, batch_x, batch_y):
def loss_fn(p):
return jnp.mean((batch_x @ p - batch_y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
# 모든 device 의 gradient 평균
grads = jax.lax.pmean(grads, axis_name="data")
loss = jax.lax.pmean(loss, axis_name="data")
new_params = params - 0.01 * grads
return new_params, loss
parallel_step = jax.pmap(
train_step,
in_axes=(None, 0, 0),
axis_name="data",
)
이게 — 4 GPU 에서 batch_size=128 학습할 때, 각 GPU 가 32 example 처리한 뒤 gradient 를 평균. 결과적으로 batch_size=128 의 학습과 동등.
여러 axis
# 2D mesh — data parallel + model parallel
parallel = jax.pmap(
jax.pmap(f, axis_name="model"),
axis_name="data",
)
# 두 axis 다 통합한 reduction
def f(x):
s_data = jax.lax.psum(x, axis_name="data")
s_model = jax.lax.psum(x, axis_name="model")
s_all = jax.lax.psum(x, axis_name=("data", "model"))
return s_all
큰 모델에서는 — data axis 로 batch 를 나누고, model axis 로 parameters 를 나누는 2D 패턴이 흔함.
🔄 collective 의 비용
collective op 은 — 네트워크 비용. 4 GPU 가 psum 하면 — 각 GPU 의 데이터를 다른 GPU 로 전송. 큰 model 에서 gradient (수 GB) 를 매 step psum 하는 게 — 학습 속도의 병목이 되는 일이 많아. NVIDIA NCCL, JAX collective 의 효율 차이가 학습 시간에 큰 영향.
최적화 — gradient accumulation 으로 collective 빈도 감소, mixed precision 으로 데이터 크기 감소, all_reduce 대신 reduce_scatter + all_gather 패턴 (ZeRO).