2024 년 Google 이 JAX AI Stack 발표 — JAX 기반 ML 도구의 통합 패키지. Flax NNX (model), Optax (optimizer), Orbax (checkpoint), Grain (data loader), 기타. 한 번에 설치:
pip install jax-ai-stack
새 프로젝트 시작 시 — 이 한 줄로 verified-compatible version 의 라이브러리 묶음 설치. 버전 충돌 안 남.
Pallas — JAX 의 custom kernel 언어
표준 JAX op 으로 표현 안 되는 — 매우 특수한 hardware 친화 kernel. JAX 의 PyTorch Triton 비슷.
from jax.experimental import pallas as pl
@pl.pallas_call(out_shape=jax.ShapeDtypeStruct((1024,), jnp.float32))
def add_kernel(x_ref, y_ref, out_ref):
'''단순 vector add — Pallas 식으로'''
x = x_ref[...]
y = y_ref[...]
out_ref[...] = x + y
# 사용 — 표준 JAX 처럼
x = jnp.ones(1024)
y = jnp.ones(1024)
result = add_kernel(x, y)
위 — single thread block 의 vector add. 실전 — 더 복잡한 fused kernel, 특히 attention 변형:
# Flash Attention — Pallas 로 구현 (의사 코드)
@pl.pallas_call(...)
def flash_attention_kernel(q_ref, k_ref, v_ref, out_ref):
# block-wise loading
# softmax 의 online algorithm
# weighted sum 를 in-place 누적
...
왜 Pallas 가 필요한가:
- memory hierarchy 활용 — HBM ↔ SRAM 의 명시적 제어. attention 의 IO 비용이 메모리 hierarchy 에 따라 결정.
- fusion 의 한계 돌파 — XLA 의 자동 fusion 이 안 잡는 패턴.
- 새 hardware 활용 — TPU 의 sparse core, GPU 의 tensor core 의 직접 활용.
실전 — Pallas 가 빛나는 곳
| 분야 | Pallas 의 역할 |
|---|---|
| Flash Attention | 긴 sequence 의 attention 메모리 감소 |
| Sparse layers | sparse matrix multiplication 의 효율 kernel |
| Quantization | int8 / fp8 행렬 연산 |
| MoE (Mixture of Experts) | scatter/gather 의 효율 패턴 |
| LSH / approximate ops | 해시 기반 sampling, near-neighbor |
대부분의 사용자 — Pallas 가 필요 없어. 표준 JAX op 으로 충분. Pallas 가 결정적인 경우 — 매우 큰 모델, 특수 hardware 활용, 새 algorithm research.
JAX 진영의 다른 보조 라이브러리
- chex — 단위 test, 디버깅, dataclass
- clu (Common Loop Utils) — metric 누적, summary writer
- einshape / einops — tensor shape 조작
- xpilot — distributed coordinator (Google 내부 → 외부 release)
- distrax — probability distribution (DeepMind)
🛠 production stack 추천
(1) framework: Flax NNX 또는 Equinox. (2) optimizer: Optax. (3) checkpoint: Orbax. (4) data: Grain (또는 PyTorch DataLoader). (5) logging: wandb. (6) test: chex. (7) 기타: jax-ai-stack 의 의존성. Pallas 는 — 표준 op 이 막힐 때만. 이 stack 으로 — Llama-class 모델 학습 가능.
Pallas 의 위치 — 비유하자면 JAX 의 inline assembly. 99% 의 코드는 안 쓰지만 — 그 1% 가 결정적일 수 있어. 알아둘 만한 정도.