긴 quest 끝에서 — 솔직한 정리. 2026 년 현재 — JAX vs PyTorch 의 위치, 그리고 다음 단계는?
현재 (2026) 의 분포
| 분야 | 주류 | 비고 |
|---|---|---|
| industrial / production ML | PyTorch (~ 70%) | 인프라, 라이브러리, 인재 풀 |
| 학술 연구 (general) | PyTorch (~ 60%) / JAX (~ 30%) | 분야별 다름 |
| large language model 학습 | JAX (DeepMind, Google) / PyTorch (Meta, OpenAI) | 모두 큰 비중 |
| scientific / differentiable | JAX (~ 60%) | Diffrax, NumPyro 의 영향 |
| mobile / edge | PyTorch (~ 80%) / TF (~ 20%) | JAX 거의 없음 |
| RL / robotics | JAX 가 빠르게 성장 | Brax, MJX |
JAX 의 강점
- functional 합성 — jit + grad + vmap + pmap 자유롭게
- differentiable everything — physics, ODE, Bayesian 자연스럽게
- TPU 활용 — Google Cloud 의 cost-effective 학습
- 코드의 simplicity — magic 없는 — 함수 + pytree
- research velocity — 새 알고리즘 prototype 빠름
JAX 의 약점
- 인프라 / 도구 — PyTorch 보다 적음 (data loader, deploy 옵션)
- 학습 곡선 — functional 사고, PRNG key 등 진입 장벽
- 커뮤니티 — 작음. tutorial / Stack Overflow 답 적음
- NN library 분산 — Flax/Equinox/Haiku — 표준 통일 안 됨
- 모바일 — jax2tf 우회만, native 지원 거의 없음
다음 단계 — 학습 끝낸 후
1. 학술 연구 / 새 algorithm
JAX 가 강력. 추천 — 작은 model 직접 구현 (transformer scratch, 강화학습, neural ODE 등). DeepMind 의 open-source 코드를 읽기.
2. production / 큰 model 학습
JAX AI Stack — Flax NNX + Optax + Orbax. open-source LLM 학습 코드 (예: Levanter, EasyLM). 대형 모델 학습은 — 인프라 + 데이터 + computing power 의 영역. JAX 는 — 코드의 명확성 + TPU 친화로 강점.
3. scientific computing
JAX 가 가장 빛나는 곳. Diffrax (ODE), NumPyro (Bayesian), Brax (physics), JAX-MD (분자), jaxopt (최적화). 자기 분야의 라이브러리 deep-dive.
4. multimodal / vision / LLM
PyTorch 가 여전히 우세. JAX 로 가능 — 그러나 PyTorch 의 라이브러리 풍부 (transformers, diffusers 등). 한 framework 의 깊이 vs 다른 framework 의 폭의 trade-off.
5. mobile / edge / production deploy
PyTorch + ONNX 또는 PyTorch Mobile / Core ML. JAX 는 — research 단계까지. production 의 last mile 에선 PyTorch 또는 TF.
그러면 어디로?
한 가지 — 두 framework 다 알면 — 가치 큼. 한 쪽의 mental model 이 다른 쪽도 이해하게 함. JAX 의 functional + 합성 사고는 — PyTorch 코드를 더 깨끗히 짜는 데도 도움.
현재 (2026) — JAX 가 빠르게 성장 중. PyTorch 가 여전히 우세지만 — JAX 가 점점 더 많은 시나리오에서 default 가 되는 추세. 학생이 새로 배우는 framework — JAX vs PyTorch 의 비율이 매년 JAX 쪽으로 기울고 있어.
🌅 quest 의 마지막 한 마디
JAX 를 익혔다 = ML / scientific computing 의 한 큰 도구를 손에 넣었다. 두 가지를 강조: (1) 합성의 힘 — jit + grad + vmap 의 자유로운 합성이 — 다른 framework 에선 hack 인 일들을 한 줄로. (2) functional 사고 — pure function, pytree, immutable 데이터 — 처음 답답하지만 — 익으면 더 깨끗한 모델. 이 두 정신은 — JAX 를 떠나도 — 어떤 코드를 짤 때든 가치 있어.
마지막 — JAX 는 — 활발히 진화 중인 framework. 1 년 후엔 새 도구, 새 패턴이 등장. 이 quest 는 — 그 진화에 따라가는 기반 mental model. 새 라이브러리 / 패턴이 나오면 — 그것도 익혀. 그러나 — jit, grad, vmap, pytree 의 정신은 — 변하지 않을 거야.
잘 했어. 73 lessons 를 같이 걸었네. 이제 — 자기 프로젝트로. JAX 식 사고를 — 자기 분야에서 — 직접 적용해 봐. quest 끝. ownership 시작.