C.W.K.
Stream
Lesson 01 of 06 · published

JAX 가 NN 라이브러리를 안 갖는 이유

~8 min · neural-nets, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX core 에는 nn.Linear, nn.Conv2d 같은 게 없어. 처음 보면 의아하지만 — 의도된 설계야.

JAX 의 핵심 — 수치 primitive (jax.numpy) + 변환 (jit, grad, vmap, pmap). 거기까지가 core. NN 추상화는 — 따로 pick.

왜 이런 설계?

1. NN 추상화는 "어떤 패러다임이 옳은가" 에 답이 갈림

  • Mutable state (PyTorch 식): self.weight = ... 같은 instance attribute. 직관적.
  • Pure functional (Haiku 식): paramsapply_fn 분리. JAX-native.
  • Module as pytree (Equinox 식): model 자체가 pytree, jit/grad 가 직접 변환.
  • NNX (Flax 의 새 API): PyTorch 식 mutable + JAX 의 transform 잘 호환.

한 표준 강요하면 — 다른 패러다임이 막힘. JAX 는 의도적으로 비워 둠.

2. 다양한 분야가 다른 추상을 원함

  • RL — state machine 추상화 강조
  • scientific — ODE solver, simulator 와 자연스러운 연동
  • vision — CNN, transformer 표준
  • NLP — transformer, attention 표준

각 분야가 자기 라이브러리를 만들 자유.

3. 유지보수 부담 분리

NN API 는 변화가 빠름 (transformer 변형, attention 종류, normalization 변형). core 에 두면 — JAX 의 안정성 vs NN의 진화 사이의 충돌. 분리하면 — JAX core 는 천천히 안정적, NN library 는 빠르게 진화.

현재 주요 라이브러리

이름스타일주력 사용처
Flax NNXmutable Python stateGoogle, DeepMind 의 새 표준
Equinoxmodel = pytree, pure functional학술 연구, JAX-native 선호
Haikutransform 기반 (옛 Flax)DeepMind 의 옛 코드, AlphaFold
Penzai전체 모델 visualizationresearch debugging
Levantertraining scale 특화large-scale training

🎯 어느 걸 골라야 하나

(1) 새 프로젝트 + Google/DeepMind 영향권 — Flax NNX. 가장 적극적 발전. (2) 학술 연구 / functional 선호 — Equinox. JAX 의 정신과 가장 일치. (3) AlphaFold 같은 구식 코드 봐야 함 — Haiku. 그러나 새 코드는 안 추천. (4) 처음 배우면 — Flax NNX 가 PyTorch 와 가장 비슷한 ergonomics 라 진입 장벽 낮음.

이 quest 는 — Flax NNX 와 Equinox 둘 다 다룸 (10-2, 10-3). 같은 model 을 두 라이브러리로 작성해서 차이를 직접 봐.

중요한 한 가지 — 어느 라이브러리를 선택하든 — JAX 의 핵심 (jit, grad, vmap, pytree) 은 그대로. NN library 는 그 위의 syntactic sugar. 이 quest 1-9 가 다 이해된 사람은 — NN library 의 선택이 비교적 작은 결정.

Code

# PyTorch: one way to define a model
# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.linear = nn.Linear(784, 10)
#     def forward(self, x):
#         return self.linear(x)

# JAX: you choose your library
# Flax NNX version:
from flax import nnx
class Model(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(784, 10, rngs=rngs)
    def __call__(self, x):
        return self.linear(x)

# Equinox version:
import equinox as eqx
class Model(eqx.Module):
    linear: eqx.nn.Linear
    def __init__(self, key):
        self.linear = eqx.nn.Linear(784, 10, key=key)
    def __call__(self, x):
        return self.linear(x)

External links

Exercise

공식 source 3 곳에서 'JAX core 에 NN 라이브러리 없는 이유' 의 다른 의견 읽기. 'Flax/Equinox 를 따로 배워야 하는 이유?' 묻는 동료에게 100 단어 개인 답 작성. 이 story 가 — 나중에 어떤 거 고를지 좌우.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.