JAX 의 약속 — 같은 코드가 CPU, GPU, TPU 어디서나. 실전에서는 — 각 hardware 에 약간의 ergonomic 차이.
device 확인
import jax
print(jax.devices())
# [CpuDevice(id=0)] ← CPU only
# [GpuDevice(id=0, process_index=0), ...] ← CUDA GPU
# [TpuDevice(id=0, ...), ...] ← TPU
print(jax.default_backend()) # 'cpu' / 'gpu' / 'tpu'
print(jax.device_count()) # 8 (8 TPU cores)
설치 차이
# CPU
pip install -U jax
# CUDA 12 (NVIDIA)
pip install -U "jax[cuda12]"
# 또는 specific
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# TPU (Google Cloud)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Apple Silicon (실험적)
pip install -U jax-metal
Colab 에서
Colab — JAX TPU 가 무료. Runtime → Change runtime type → TPU 선택. 그러면:
import jax
print(jax.devices()) # 8 개 TPU core
print(jax.device_count()) # 8
memory 관리
JAX 는 — default 로 — process 시작 시 GPU memory 의 90% 를 미리 할당 (memory fragmentation 방지). PyTorch 와 같이 쓰면 충돌 가능:
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# 또는 specific fraction
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
Python 안에서:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax
multiple GPU
# 자동 — 모든 GPU 가 visible
import jax
print(jax.device_count()) # 4 (4 GPU)
# 특정 device 만 사용
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" # GPU 0, 1 만
TPU 의 특성
- 한 TPU pod 가 — 보통 4 또는 8 chip. 각 chip 이 multiple core.
- matrix multiplication 에 매우 빠름 (TPU 의 systolic array).
- 매우 큰 batch size 에 유리.
- bfloat16 이 native 데이터 타입.
- dynamic shape 에 약함 — shape 변화 = recompile.
GPU 의 특성
- NVIDIA CUDA — 가장 많이 쓰는 backend.
- more flexibility — kernel 작성 가능 (Pallas, Triton).
- multi-process — multi-host 학습 시 NCCL 사용.
- memory 가 TPU 보다 적음 (보통).
multi-host 학습
여러 machine — 각자 multi-GPU. jax.distributed.initialize() 로 cluster 설정:
import jax
jax.distributed.initialize(
coordinator_address="10.0.0.1:1234",
num_processes=4,
process_id=process_idx, # 각 host 에서 다른 값
)
# 이제 모든 host 의 device 가 jax.devices() 에 보임
# pmap / sharding 이 자동으로 cross-host 통신 처리
SLURM, Kubernetes 등에서 launching — 각 framework 의 best practice.
🔋 hardware 별 sweet spot
(1) 학습 / research — TPU 가 단일 작업당 cost-effective 한 경우 많음 (Google Cloud). (2) inference / deployment — GPU 가 더 친숙한 경우 많음 (NVIDIA Triton 등 generic infrastructure). (3) on-device / mobile — Apple Silicon 의 jax-metal 또는 jax2tf → TF Lite. (4) CPU 만 — 작은 모델, 학습 / inference 가능 — hardware 부족할 때 시작.
같은 코드가 — hardware 만 바꾸면 — 그대로 도는 게 JAX 의 약속. 100% 자동은 아니고 — async dispatch, memory 관리 등 — hardware 별 약간의 tuning.