C.W.K.
Stream
Lesson 02 of 05 · published

TPU 와 GPU 위의 JAX

~8 min · ecosystem, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 약속 — 같은 코드가 CPU, GPU, TPU 어디서나. 실전에서는 — 각 hardware 에 약간의 ergonomic 차이.

device 확인

import jax
print(jax.devices())
# [CpuDevice(id=0)]                          ← CPU only
# [GpuDevice(id=0, process_index=0), ...]    ← CUDA GPU
# [TpuDevice(id=0, ...), ...]                 ← TPU

print(jax.default_backend())   # 'cpu' / 'gpu' / 'tpu'
print(jax.device_count())      # 8 (8 TPU cores)

설치 차이

# CPU
pip install -U jax

# CUDA 12 (NVIDIA)
pip install -U "jax[cuda12]"
# 또는 specific
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# TPU (Google Cloud)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Apple Silicon (실험적)
pip install -U jax-metal

Colab 에서

Colab — JAX TPU 가 무료. Runtime → Change runtime type → TPU 선택. 그러면:

import jax
print(jax.devices())   # 8 개 TPU core
print(jax.device_count())   # 8

memory 관리

JAX 는 — default 로 — process 시작 시 GPU memory 의 90% 를 미리 할당 (memory fragmentation 방지). PyTorch 와 같이 쓰면 충돌 가능:

export XLA_PYTHON_CLIENT_PREALLOCATE=false
# 또는 specific fraction
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

Python 안에서:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax

multiple GPU

# 자동 — 모든 GPU 가 visible
import jax
print(jax.device_count())   # 4 (4 GPU)

# 특정 device 만 사용
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"   # GPU 0, 1 만

TPU 의 특성

  • 한 TPU pod 가 — 보통 4 또는 8 chip. 각 chip 이 multiple core.
  • matrix multiplication 에 매우 빠름 (TPU 의 systolic array).
  • 매우 큰 batch size 에 유리.
  • bfloat16 이 native 데이터 타입.
  • dynamic shape 에 약함 — shape 변화 = recompile.

GPU 의 특성

  • NVIDIA CUDA — 가장 많이 쓰는 backend.
  • more flexibility — kernel 작성 가능 (Pallas, Triton).
  • multi-process — multi-host 학습 시 NCCL 사용.
  • memory 가 TPU 보다 적음 (보통).

multi-host 학습

여러 machine — 각자 multi-GPU. jax.distributed.initialize() 로 cluster 설정:

import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=4,
    process_id=process_idx,   # 각 host 에서 다른 값
)

# 이제 모든 host 의 device 가 jax.devices() 에 보임
# pmap / sharding 이 자동으로 cross-host 통신 처리

SLURM, Kubernetes 등에서 launching — 각 framework 의 best practice.

🔋 hardware 별 sweet spot

(1) 학습 / research — TPU 가 단일 작업당 cost-effective 한 경우 많음 (Google Cloud). (2) inference / deployment — GPU 가 더 친숙한 경우 많음 (NVIDIA Triton 등 generic infrastructure). (3) on-device / mobile — Apple Silicon 의 jax-metal 또는 jax2tf → TF Lite. (4) CPU 만 — 작은 모델, 학습 / inference 가능 — hardware 부족할 때 시작.

같은 코드가 — hardware 만 바꾸면 — 그대로 도는 게 JAX 의 약속. 100% 자동은 아니고 — async dispatch, memory 관리 등 — hardware 별 약간의 tuning.

Code

import jax

# Check available devices
print(jax.devices())
# [CudaDevice(id=0), CudaDevice(id=1), ...]

# JAX automatically uses GPU if available
x = jax.numpy.ones((1000, 1000))
# x is already on GPU — no .to('cuda') needed!

# Multi-GPU: use sharding (see Track 12)
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()  # all GPUs
mesh = Mesh(devices, ('data',))

# Data parallelism: shard batch across GPUs
data_sharding = NamedSharding(mesh, P('data'))
batch = jax.device_put(x, data_sharding)

# As of JAX 0.6.0: requires CUDA 12.8+
# pip install jax[cuda12]
# On Google Cloud TPU VMs:
# pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# TPU pods: automatic multi-host setup
print(jax.device_count())  # e.g., 8 for a v4-8, 32 for a v4-32

# TPU-specific tips:
# 1. Use bfloat16 — TPUs have native bfloat16 support
x_bf16 = x.astype(jax.numpy.bfloat16)

# 2. Pad batch sizes to multiples of 128 (TPU-friendly)
# 3. Avoid scalar operations — TPUs are designed for large tensor ops
# 4. Use jax.profiler for TPU-specific profiling
# Common performance pitfalls:

# 1. UNNECESSARY RECOMPILATION
# BAD: different shapes cause recompilation
for batch in variable_size_batches:
    result = jax.jit(fn)(batch)  # recompiles every new shape!

# GOOD: pad to fixed size
max_batch_size = 256
for batch in batches:
    padded = pad_to_size(batch, max_batch_size)
    result = jax.jit(fn)(padded)  # compiled once, reused

# 2. HOST-DEVICE TRANSFER
# BAD: pulling values back to CPU in a loop
for step in range(1000):
    loss = train_step(params, batch)
    print(float(loss))  # blocks! transfers to CPU every step

# GOOD: only transfer periodically
for step in range(1000):
    loss = train_step(params, batch)
    if step % 100 == 0:
        print(float(loss))  # transfer only every 100 steps

# 3. USE jax.block_until_ready() for timing
import time
x = jax.numpy.ones((1000, 1000))
start = time.time()
y = x @ x
y.block_until_ready()  # wait for computation to finish
print(f"Time: {time.time() - start:.4f}s")

External links

Exercise

GPU (또는 Colab TPU) 가 있으면 jax.devices(). 이전 script 하나를 코드 안 바꾸고 옮김. 같은 matmul 의 CPU vs GPU vs TPU 측정. speedup 솔직히 읽기 — JAX 의 portability 가 lesson.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.