C.W.K.
Stream
Lesson 05 of 05 · published

2026 의 JAX vs PyTorch, 그리고 다음 길

~12 min · ecosystem, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

긴 quest 끝에서 — 솔직한 정리. 2026 년 현재 — JAX vs PyTorch 의 위치, 그리고 다음 단계는?

현재 (2026) 의 분포

분야주류비고
industrial / production MLPyTorch (~ 70%)인프라, 라이브러리, 인재 풀
학술 연구 (general)PyTorch (~ 60%) / JAX (~ 30%)분야별 다름
large language model 학습JAX (DeepMind, Google) / PyTorch (Meta, OpenAI)모두 큰 비중
scientific / differentiableJAX (~ 60%)Diffrax, NumPyro 의 영향
mobile / edgePyTorch (~ 80%) / TF (~ 20%)JAX 거의 없음
RL / roboticsJAX 가 빠르게 성장Brax, MJX

JAX 의 강점

  • functional 합성 — jit + grad + vmap + pmap 자유롭게
  • differentiable everything — physics, ODE, Bayesian 자연스럽게
  • TPU 활용 — Google Cloud 의 cost-effective 학습
  • 코드의 simplicity — magic 없는 — 함수 + pytree
  • research velocity — 새 알고리즘 prototype 빠름

JAX 의 약점

  • 인프라 / 도구 — PyTorch 보다 적음 (data loader, deploy 옵션)
  • 학습 곡선 — functional 사고, PRNG key 등 진입 장벽
  • 커뮤니티 — 작음. tutorial / Stack Overflow 답 적음
  • NN library 분산 — Flax/Equinox/Haiku — 표준 통일 안 됨
  • 모바일 — jax2tf 우회만, native 지원 거의 없음

다음 단계 — 학습 끝낸 후

1. 학술 연구 / 새 algorithm

JAX 가 강력. 추천 — 작은 model 직접 구현 (transformer scratch, 강화학습, neural ODE 등). DeepMind 의 open-source 코드를 읽기.

2. production / 큰 model 학습

JAX AI Stack — Flax NNX + Optax + Orbax. open-source LLM 학습 코드 (예: Levanter, EasyLM). 대형 모델 학습은 — 인프라 + 데이터 + computing power 의 영역. JAX 는 — 코드의 명확성 + TPU 친화로 강점.

3. scientific computing

JAX 가 가장 빛나는 곳. Diffrax (ODE), NumPyro (Bayesian), Brax (physics), JAX-MD (분자), jaxopt (최적화). 자기 분야의 라이브러리 deep-dive.

4. multimodal / vision / LLM

PyTorch 가 여전히 우세. JAX 로 가능 — 그러나 PyTorch 의 라이브러리 풍부 (transformers, diffusers 등). 한 framework 의 깊이 vs 다른 framework 의 폭의 trade-off.

5. mobile / edge / production deploy

PyTorch + ONNX 또는 PyTorch Mobile / Core ML. JAX 는 — research 단계까지. production 의 last mile 에선 PyTorch 또는 TF.

그러면 어디로?

한 가지 — 두 framework 다 알면 — 가치 큼. 한 쪽의 mental model 이 다른 쪽도 이해하게 함. JAX 의 functional + 합성 사고는 — PyTorch 코드를 더 깨끗히 짜는 데도 도움.

현재 (2026) — JAX 가 빠르게 성장 중. PyTorch 가 여전히 우세지만 — JAX 가 점점 더 많은 시나리오에서 default 가 되는 추세. 학생이 새로 배우는 framework — JAX vs PyTorch 의 비율이 매년 JAX 쪽으로 기울고 있어.

🌅 quest 의 마지막 한 마디

JAX 를 익혔다 = ML / scientific computing 의 한 큰 도구를 손에 넣었다. 두 가지를 강조: (1) 합성의 힘 — jit + grad + vmap 의 자유로운 합성이 — 다른 framework 에선 hack 인 일들을 한 줄로. (2) functional 사고 — pure function, pytree, immutable 데이터 — 처음 답답하지만 — 익으면 더 깨끗한 모델. 이 두 정신은 — JAX 를 떠나도 — 어떤 코드를 짤 때든 가치 있어.

마지막 — JAX 는 — 활발히 진화 중인 framework. 1 년 후엔 새 도구, 새 패턴이 등장. 이 quest 는 — 그 진화에 따라가는 기반 mental model. 새 라이브러리 / 패턴이 나오면 — 그것도 익혀. 그러나 — jit, grad, vmap, pytree 의 정신은 — 변하지 않을 거야.

잘 했어. 73 lessons 를 같이 걸었네. 이제 — 자기 프로젝트로. JAX 식 사고를 — 자기 분야에서 — 직접 적용해 봐. quest 끝. ownership 시작.

Code

# JAX shines when you need:
# 1. Composable transformations (grad of grad, vmap of grad)
# 2. TPU support (JAX is first-class on TPUs)
# 3. Functional programming style
# 4. Scientific computing + ML hybrid workloads
# 5. Custom differentiation rules
# 6. Research that pushes framework boundaries

# PyTorch is better when you need:
# 1. Largest community and ecosystem
# 2. Dynamic computation graphs with easy debugging
# 3. Rapid prototyping with less boilerplate
# 4. Production deployment tooling (TorchServe, etc.)
# 5. Most SOTA model implementations available first
# Key JAX milestones:
# JAX 0.5.0 (Feb 2025) — Partitionable PRNG by default
# JAX 0.6.0 (Apr 2025) — CUDA 12.8+ required, API cleanup
# JAX 0.7.0 (Jul 2025) — Migrated from GSPMD to Shardy,
#                         direct linearization, jax.P alias,
#                         minimum Python 3.11
# JAX 0.8.x (Dec 2025) — Continued stabilization
# JAX 0.9.x (Jan 2026) — Latest release as of early 2026
# Resources for continuing your JAX journey:

# Official documentation
# - jax.readthedocs.io (core JAX)
# - flax.readthedocs.io (Flax NNX)
# - optax.readthedocs.io (Optax)
# - docs.kidger.site/equinox (Equinox)
# - docs.kidger.site/diffrax (Diffrax)
# - num.pyro.ai (NumPyro)

# Learning projects to try:
# 1. Train a small Transformer on a text dataset
# 2. Implement a variational autoencoder (VAE)
# 3. Solve a differential equation with Diffrax
# 4. Build a Bayesian neural network with NumPyro
# 5. Write a differentiable physics simulation
# 6. Fine-tune a model with LoRA using Flax NNX

# Community
# - github.com/jax-ml/jax (source + issues)
# - JAX Discussions on GitHub
# - r/MachineLearning on Reddit
# - JAX Discord / Slack communities

External links

Exercise

lesson 의 한 방향 선택 (research, scientific, production, custom kernel) — 1 주 mini-project commit. single markdown: scope, success criterion, 첫 3 구체 step. quest 끝, ownership 시작.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.