JAX 본체 (jax + jaxlib) 는 의도적으로 작게 유지돼. 변환 + 수치 + XLA 인터페이스 거기까지. 위에 얹혀서 실제 ML / 과학 계산을 하는 건 — 별도 라이브러리들의 생태계야.
- Flax (NNX): Google 의 JAX-native NN 라이브러리.
- Equinox: model 을 그냥 pytree 로 본다. 더 functional.
- Optax: Optimizer + gradient transformation 라이브러리. 사실상 표준.
- Orbax: checkpoint 저장 / 복원 표준.
- Diffrax: differentiable ODE/SDE solver.
- NumPyro: Bayesian inference + probabilistic programming.
- Brax: GPU/TPU 에서 도는 differentiable rigid-body physics.
설치 — 첫 단추가 platform 별로 좀 다름:
# CPU 만 (지금 시작하기 가장 쉬움)
pip install -U jax
# NVIDIA GPU (CUDA 12)
pip install -U "jax[cuda12]"
# Cloud TPU
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
설치 sanity check:
import jax
import jax.numpy as jnp
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())
print("Default backend:", jax.default_backend())
x = jnp.array([1.0, 2.0, 3.0])
print(jnp.sum(x)) # 6.0
💡 conda + virtualenv 권장
JAX 는 jaxlib 의 platform 별 binary 와 fit 이 까다로워. 새 venv / conda env 하나 따로 만들어서 시작해. 이 quest 끝까지 그 환경 재사용할 거야.