JAX core 에는 nn.Linear, nn.Conv2d 같은 게 없어. 처음 보면 의아하지만 — 의도된 설계야.
JAX 의 핵심 — 수치 primitive (jax.numpy) + 변환 (jit, grad, vmap, pmap). 거기까지가 core. NN 추상화는 — 따로 pick.
왜 이런 설계?
1. NN 추상화는 "어떤 패러다임이 옳은가" 에 답이 갈림
- Mutable state (PyTorch 식):
self.weight = ...같은 instance attribute. 직관적. - Pure functional (Haiku 식):
params와apply_fn분리. JAX-native. - Module as pytree (Equinox 식): model 자체가 pytree, jit/grad 가 직접 변환.
- NNX (Flax 의 새 API): PyTorch 식 mutable + JAX 의 transform 잘 호환.
한 표준 강요하면 — 다른 패러다임이 막힘. JAX 는 의도적으로 비워 둠.
2. 다양한 분야가 다른 추상을 원함
- RL — state machine 추상화 강조
- scientific — ODE solver, simulator 와 자연스러운 연동
- vision — CNN, transformer 표준
- NLP — transformer, attention 표준
각 분야가 자기 라이브러리를 만들 자유.
3. 유지보수 부담 분리
NN API 는 변화가 빠름 (transformer 변형, attention 종류, normalization 변형). core 에 두면 — JAX 의 안정성 vs NN의 진화 사이의 충돌. 분리하면 — JAX core 는 천천히 안정적, NN library 는 빠르게 진화.
현재 주요 라이브러리
| 이름 | 스타일 | 주력 사용처 |
|---|---|---|
| Flax NNX | mutable Python state | Google, DeepMind 의 새 표준 |
| Equinox | model = pytree, pure functional | 학술 연구, JAX-native 선호 |
| Haiku | transform 기반 (옛 Flax) | DeepMind 의 옛 코드, AlphaFold |
| Penzai | 전체 모델 visualization | research debugging |
| Levanter | training scale 특화 | large-scale training |
🎯 어느 걸 골라야 하나
(1) 새 프로젝트 + Google/DeepMind 영향권 — Flax NNX. 가장 적극적 발전. (2) 학술 연구 / functional 선호 — Equinox. JAX 의 정신과 가장 일치. (3) AlphaFold 같은 구식 코드 봐야 함 — Haiku. 그러나 새 코드는 안 추천. (4) 처음 배우면 — Flax NNX 가 PyTorch 와 가장 비슷한 ergonomics 라 진입 장벽 낮음.
이 quest 는 — Flax NNX 와 Equinox 둘 다 다룸 (10-2, 10-3). 같은 model 을 두 라이브러리로 작성해서 차이를 직접 봐.
중요한 한 가지 — 어느 라이브러리를 선택하든 — JAX 의 핵심 (jit, grad, vmap, pytree) 은 그대로. NN library 는 그 위의 syntactic sugar. 이 quest 1-9 가 다 이해된 사람은 — NN library 의 선택이 비교적 작은 결정.