"근데 PyTorch 가 모든 걸 다 하는데 왜 JAX 야?" — 합리적 질문이야. 답은: 다른 패러다임이고, 각자 다른 것에 강해.
PyTorch: dynamic, eager, OOP. nn.Module 상속, forward 메서드, .backward() 호출하면 magic 처럼 gradient 가 parameter 에 붙음. 처음 배우는 사람한테 친절. 단점은 — graph 구조가 명시적이지 않아서 compiler 최적화 제한.
TensorFlow: 원래 graph-first 였는데 2.x 부터 eager mode default. 산업 인프라 (TFX, TF Serving, TF Lite, TFJS) 강력. 연구실에선 점점 덜 쓰는 추세.
JAX: functional, transformation-first. nn.Module 같은 거 없음 (Flax/Equinox 가 따로 제공). jit/grad/vmap/pmap 이 1 등 시민, 자유롭게 합성. compile-first 라서 XLA 최적화 깊음. 단점은 — 처음 진입 장벽이 있어 (functional 사고, pure function, PRNG key 등).
| PyTorch | TensorFlow | JAX
---------|------------------|------------------|------------------
스타일 | OOP, eager | OOP/graph hybrid | functional
미분 | tensor.backward()| tape 기반 | jax.grad (함수)
batch | 손으로 처리 | 손으로 처리 | jax.vmap
multi-GPU| DDP / FSDP | tf.distribute | pmap / sharding
연구 | 1 위 (분야 다수) | 감소 | 상승 (DeepMind 등)
🧭 어느 걸 골라야 하나
처음이면 PyTorch 부터. 큰 회사 production 이면 TF 도 여전히. 연구 — 특히 functional / transform-heavy / TPU 활용 — 면 JAX. 답은 "둘 다 알면 좋음" 인데, 이 quest 는 JAX 의 왜 를 가르치는 게 목표야.
중요한 건 — JAX 가 PyTorch 를 죽이러 온 게 아니야. 같은 문제에 다른 답을 제시하는 거. 두 답이 공존하는 시대를 살고 있어.