같은 model — 작은 CNN (CIFAR-10 분류) — 을 두 라이브러리로 만들어. 차이점 직접 비교.
Flax NNX 버전
from flax import nnx
import jax
import jax.numpy as jnp
class CNN(nnx.Module):
def __init__(self, *, rngs):
self.conv1 = nnx.Conv(3, 32, (3, 3), padding="SAME", rngs=rngs)
self.conv2 = nnx.Conv(32, 64, (3, 3), padding="SAME", rngs=rngs)
self.conv3 = nnx.Conv(64, 128, (3, 3), padding="SAME", rngs=rngs)
self.pool = lambda x: nnx.avg_pool(x, (2, 2), (2, 2))
self.dense = nnx.Linear(128 * 4 * 4, 10, rngs=rngs)
def __call__(self, x):
# x: (B, H, W, C)
x = jax.nn.relu(self.conv1(x))
x = self.pool(x) # (B, H/2, W/2, 32)
x = jax.nn.relu(self.conv2(x))
x = self.pool(x) # (B, H/4, W/4, 64)
x = jax.nn.relu(self.conv3(x))
x = self.pool(x) # (B, H/8, W/8, 128)
x = x.reshape(x.shape[0], -1) # flatten
return self.dense(x)
model = CNN(rngs=nnx.Rngs(0))
# 학습 step
@nnx.jit
def train_step(model, x, y):
def loss_fn(model):
logits = model(x)
loss = jnp.mean(
-jnp.sum(jax.nn.log_softmax(logits) * y, axis=-1)
)
return loss
grad_fn = nnx.value_and_grad(loss_fn)
loss, grads = grad_fn(model)
nnx.update(model, jax.tree.map(
lambda p, g: p - 0.001 * g,
nnx.state(model, nnx.Param),
grads,
))
return loss
Equinox 버전
import equinox as eqx
import jax
import jax.numpy as jnp
class CNN(eqx.Module):
conv1: eqx.nn.Conv2d
conv2: eqx.nn.Conv2d
conv3: eqx.nn.Conv2d
dense: eqx.nn.Linear
def __init__(self, key):
k1, k2, k3, k4 = jax.random.split(key, 4)
self.conv1 = eqx.nn.Conv2d(3, 32, kernel_size=3, padding=1, key=k1)
self.conv2 = eqx.nn.Conv2d(32, 64, kernel_size=3, padding=1, key=k2)
self.conv3 = eqx.nn.Conv2d(64, 128, kernel_size=3, padding=1, key=k3)
self.dense = eqx.nn.Linear(128 * 4 * 4, 10, key=k4)
def __call__(self, x):
# x: (C, H, W) — Equinox 는 channel-first, 단일 example
x = jax.nn.relu(self.conv1(x))
x = jax.nn.avg_pool(x[None], (1, 2, 2))[0]
x = jax.nn.relu(self.conv2(x))
x = jax.nn.avg_pool(x[None], (1, 2, 2))[0]
x = jax.nn.relu(self.conv3(x))
x = jax.nn.avg_pool(x[None], (1, 2, 2))[0]
x = x.reshape(-1)
return self.dense(x)
model = CNN(jax.random.PRNGKey(0))
# 학습 step — vmap 으로 batch 처리
def loss_fn(model, X, Y):
logits = jax.vmap(model)(X)
return jnp.mean(-jnp.sum(jax.nn.log_softmax(logits) * Y, axis=-1))
@jax.jit
def train_step(model, X, Y):
loss, grads = jax.value_and_grad(loss_fn)(model, X, Y)
new_model = jax.tree.map(lambda p, g: p - 0.001 * g, model, grads)
return new_model, loss
주요 차이점
| 항목 | Flax NNX | Equinox |
|---|---|---|
| Layer 정의 | __init__ 에서 mutable 자기 attr | dataclass 로 명시 + __init__ |
| Batch 처리 | 자동 (NNX 내부 broadcasting) | vmap 명시적 |
| Image layout | NHWC (TF/Jax 표준) | CHW (PyTorch 식) |
| 학습 step | nnx.jit + nnx.update | jax.jit + tree.map |
| Param 추출 | nnx.state | tree.leaves (자동) |
코드 라인 수 — 비슷. ergonomics 가 다른데, 어느 게 좋다는 절대적 답이 없어.
📐 선택 기준
(1) team 이 이미 PyTorch 친화적 — Flax NNX (transition 부드러움). (2) 학술 / 논문 / functional 사고 — Equinox. (3) AlphaFold / DeepMind 옛 코드 봐야 함 — Haiku (이젠 새 코드는 안 추천). 실전 프로젝트 — 한 번 골라서 일관성. mixed 는 디버깅 어려움.
둘 다 production 에서 잘 작동. 학습 동역학 (loss curve, 정확도) 은 동일 — 그냥 코드 작성 스타일의 차이.