C.W.K.
Stream
Lesson 05 of 06 · published

CNN 만들기 — 두 라이브러리 비교

~11 min · neural-nets, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

같은 model — 작은 CNN (CIFAR-10 분류) — 을 두 라이브러리로 만들어. 차이점 직접 비교.

Flax NNX 버전

from flax import nnx
import jax
import jax.numpy as jnp

class CNN(nnx.Module):
    """Small 3-conv CNN for NHWC RGB batches, written in Flax NNX style.

    Every stage is conv -> ReLU -> 2x2/stride-2 average pool, so the spatial
    size halves three times. The final linear layer expects 128 * 4 * 4
    features, i.e. a 32x32 input (CIFAR-10 sized), and emits 10 logits.
    """

    def __init__(self, *, rngs):
        # Creation order is kept fixed: each layer draws its init key from
        # ``rngs`` in sequence, so reordering would change the initial weights.
        self.conv1 = nnx.Conv(3, 32, (3, 3), padding="SAME", rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, (3, 3), padding="SAME", rngs=rngs)
        self.conv3 = nnx.Conv(64, 128, (3, 3), padding="SAME", rngs=rngs)
        # Parameter-free pooling kept as a plain callable attribute.
        self.pool = lambda x: nnx.avg_pool(x, (2, 2), (2, 2))
        # 128 channels * (32 / 8) * (32 / 8) spatial positions after 3 pools.
        self.dense = nnx.Linear(128 * 4 * 4, 10, rngs=rngs)

    def __call__(self, x):
        """Map a (B, H, W, 3) batch to (B, 10) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            # Channels grow 32 -> 64 -> 128 while H and W are halved each stage.
            x = self.pool(jax.nn.relu(conv(x)))
        batch_size = x.shape[0]
        flat = x.reshape(batch_size, -1)
        return self.dense(flat)

model = CNN(rngs=nnx.Rngs(default=0))  # seed 0 -> reproducible initial weights

# 학습 step
@nnx.jit
def train_step(model, x, y):
    def loss_fn(model):
        logits = model(x)
        loss = jnp.mean(
            -jnp.sum(jax.nn.log_softmax(logits) * y, axis=-1)
        )
        return loss

    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model)
    nnx.update(model, jax.tree.map(
        lambda p, g: p - 0.001 * g,
        nnx.state(model, nnx.Param),
        grads,
    ))
    return loss

Equinox 버전

import equinox as eqx
import jax
import jax.numpy as jnp

class CNN(eqx.Module):
    """Small 3-conv CNN for a single channel-first (3, 32, 32) image, Equinox style.

    Equinox layers process one unbatched example; batch with ``jax.vmap``.
    Each stage is conv -> ReLU -> 2x2/stride-2 average pool (32 -> 16 -> 8 -> 4),
    then a linear layer maps the flattened 128 * 4 * 4 features to 10 logits.
    """

    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    dense: eqx.nn.Linear

    def __init__(self, key):
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.conv1 = eqx.nn.Conv2d(3, 32, kernel_size=3, padding=1, key=k1)
        self.conv2 = eqx.nn.Conv2d(32, 64, kernel_size=3, padding=1, key=k2)
        self.conv3 = eqx.nn.Conv2d(64, 128, kernel_size=3, padding=1, key=k3)
        # 128 * 4 * 4 assumes a 32x32 input (three 2x pooling stages).
        self.dense = eqx.nn.Linear(128 * 4 * 4, 10, key=k4)

    def __call__(self, x):
        # x: (C, H, W) — Equinox is channel-first and handles a single example.
        # jax.nn has no pooling function, so use Equinox's parameter-free
        # AvgPool2d, which takes (C, H, W) input. It is built inline rather than
        # stored as a field, so it is not part of the model pytree that
        # train_step jit-compiles and differentiates.
        pool = eqx.nn.AvgPool2d(kernel_size=2, stride=2)
        x = pool(jax.nn.relu(self.conv1(x)))  # (32, H/2, W/2)
        x = pool(jax.nn.relu(self.conv2(x)))  # (64, H/4, W/4)
        x = pool(jax.nn.relu(self.conv3(x)))  # (128, H/8, W/8)
        x = x.reshape(-1)
        return self.dense(x)

model = CNN(key=jax.random.PRNGKey(0))  # seed 0 -> reproducible initial weights

# Training step — batching is done explicitly with jax.vmap.
def loss_fn(model, X, Y):
    """Return the mean softmax cross-entropy of ``model`` over a batch.

    ``model`` maps one example to a logit vector; ``jax.vmap`` applies it
    along axis 0 of ``X``. ``Y`` is multiplied elementwise with the
    log-probabilities, so it should be one-hot (or soft) labels shaped like
    the batched logits.
    """
    batched_model = jax.vmap(model)
    log_probs = jax.nn.log_softmax(batched_model(X))
    per_example_nll = -(log_probs * Y).sum(axis=-1)
    return per_example_nll.mean()

@jax.jit
def train_step(model, X, Y):
    """Apply one plain-SGD step (learning rate 0.001) and return ``(new_model, loss)``.

    The update is functional: a new model pytree is built with
    ``jax.tree.map`` instead of mutating ``model``. Using plain ``jax.jit`` /
    ``jax.value_and_grad`` assumes every pytree leaf of ``model`` is a float
    array. For models with non-array leaves, Equinox recommends
    ``eqx.filter_jit`` / ``eqx.filter_value_and_grad``.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(model, X, Y)

    def sgd(param, grad):
        return param - 0.001 * grad

    updated = jax.tree.map(sgd, model, grads)
    return updated, loss

주요 차이점

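| 항목 | Flax NNX | Equinox |
|---|---|---|
| Layer 정의 | `__init__` 에서 mutable attribute 로 할당 | dataclass 식 field 선언 + `__init__` |
| Batch 처리 | 자동 (layer 가 leading batch 차원을 그대로 받음) | `jax.vmap` 으로 명시적 |
| Image layout | NHWC (TF/JAX 표준) | CHW (PyTorch 식, 단일 example) |
| 학습 step | `nnx.jit` + `nnx.update` (in-place 갱신) | `jax.jit` + `jax.tree.map` (새 model 반환) |
| Param 추출 | `nnx.state(model, nnx.Param)` | model 자체가 pytree (`jax.tree.leaves` 로 자동) |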

코드 라인 수 — 비슷. ergonomics 가 다른데, 어느 게 좋다는 절대적 답이 없어.

📐 선택 기준

(1) team 이 이미 PyTorch 친화적 — Flax NNX (transition 부드러움). (2) 학술 / 논문 / functional 사고 — Equinox. (3) AlphaFold / DeepMind 옛 코드 봐야 함 — Haiku (이젠 새 코드는 안 추천). 실전 프로젝트 — 한 번 골라서 일관성. mixed 는 디버깅 어려움.

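둘 다 production 에서 잘 작동. 구조·초기화·optimizer 를 똑같이 맞추면 학습 동역학 (loss curve, 정확도) 도 사실상 같아. 차이는 코드 작성 스타일이야. 단, 두 라이브러리의 기본 weight 초기화가 달라 (Flax: LeCun normal, Equinox: PyTorch 식 uniform). 그래서 아무것도 맞추지 않고 돌리면 curve 가 조금씩 달라져. 아래 Code 예제도 NNX 쪽은 BatchNorm, Equinox 쪽은 GroupNorm 을 써서 완전히 같은 model 은 아니야.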

Code

from flax import nnx
import jax.numpy as jnp

class CNN_NNX(nnx.Module):
    """MNIST-sized CNN in Flax NNX: two conv/BatchNorm/max-pool stages plus an MLP head.

    Input is an NHWC batch of shape (batch, 28, 28, 1); output is
    (batch, num_classes) logits. BatchNorm and Dropout follow NNX's own
    train/eval mode flags.
    """

    def __init__(self, num_classes: int, rngs: nnx.Rngs):
        # Keep this creation order: layers draw from ``rngs`` sequentially,
        # so reordering would change the initialization.
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), padding='SAME', rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), padding='SAME', rngs=rngs)
        self.bn1 = nnx.BatchNorm(32, rngs=rngs)
        self.bn2 = nnx.BatchNorm(64, rngs=rngs)
        # Two 2x2 pools take 28x28 down to 7x7, hence 64 * 7 * 7 features.
        self.linear1 = nnx.Linear(64 * 7 * 7, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, num_classes, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

    def __call__(self, x):
        """Map (batch, 28, 28, 1) images to (batch, num_classes) logits."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = nnx.relu(norm(conv(x)))
            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        flat = x.reshape(x.shape[0], -1)
        hidden = self.dropout(nnx.relu(self.linear1(flat)))
        return self.linear2(hidden)

# Build the model and push a dummy batch through it to check the output shape.
# NNX modules start in training mode, so this call also applies dropout and
# updates BatchNorm running statistics.
model_nnx = CNN_NNX(num_classes=10, rngs=nnx.Rngs(default=42))
dummy = jnp.ones(shape=(4, 28, 28, 1))
out = model_nnx(dummy)
print(f"NNX output shape: {out.shape}")  # (4, 10)
import equinox as eqx
import jax
import jax.numpy as jnp

class CNN_Eqx(eqx.Module):
    """MNIST-sized CNN in Equinox: two conv/GroupNorm/max-pool stages plus an MLP head.

    Works on a single unbatched (28, 28, 1) channels-last image and returns
    ``num_classes`` logits; batch it with ``jax.vmap``.
    """

    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    # GroupNorm instead of BatchNorm: it is stateless, so no running statistics
    # have to be threaded through the call (Equinox's BatchNorm needs explicit state).
    ln1: eqx.nn.GroupNorm
    ln2: eqx.nn.GroupNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, num_classes: int, *, key):
        conv1_key, conv2_key, fc1_key, fc2_key = jax.random.split(key, 4)
        self.conv1 = eqx.nn.Conv2d(1, 32, kernel_size=3, padding=1, key=conv1_key)
        self.conv2 = eqx.nn.Conv2d(32, 64, kernel_size=3, padding=1, key=conv2_key)
        self.ln1 = eqx.nn.GroupNorm(groups=8, channels=32)
        self.ln2 = eqx.nn.GroupNorm(groups=8, channels=64)
        # Two 2x2 pools take 28x28 down to 7x7, hence 64 * 7 * 7 features.
        self.linear1 = eqx.nn.Linear(64 * 7 * 7, 256, key=fc1_key)
        self.linear2 = eqx.nn.Linear(256, num_classes, key=fc2_key)
        self.dropout = eqx.nn.Dropout(p=0.5)

    def __call__(self, x, *, key=None):
        """Classify one (28, 28, 1) image into ``num_classes`` logits.

        ``key`` drives dropout; Equinox's Dropout raises if it is None while
        the layer is not in inference mode.
        """
        pool = eqx.nn.MaxPool2d(kernel_size=2, stride=2)
        # (28, 28, 1) -> (1, 28, 28): Equinox convolutions are channel-first.
        x = jnp.transpose(x, (2, 0, 1))
        for conv, norm in ((self.conv1, self.ln1), (self.conv2, self.ln2)):
            x = pool(jax.nn.relu(norm(conv(x))))
        hidden = jax.nn.relu(self.linear1(x.reshape(-1)))
        hidden = self.dropout(hidden, key=key)
        return self.linear2(hidden)

# Create and use — vmap for batching.
model_eqx = CNN_Eqx(num_classes=10, key=jax.random.key(42))
dummy_single = jnp.ones((28, 28, 1))
out = model_eqx(dummy_single, key=jax.random.key(1))
print(f"Equinox single output shape: {out.shape}")  # (10,)

# Batch with vmap. Keyword arguments are always mapped over their leading axis,
# so each of the 4 examples receives its own dropout key from `keys`.
# Passing key=None instead would make the (training-mode) Dropout raise.
# An in_axes tuple must have one entry per positional argument, so none is
# given here: the single positional argument `batch` is mapped over axis 0.
keys = jax.random.split(jax.random.key(2), 4)
batch = jnp.ones((4, 28, 28, 1))
out_batch = jax.vmap(model_eqx)(batch, key=keys)
print(f"Equinox batch output shape: {out_batch.shape}")  # (4, 10)

External links

Exercise

Flax NNX 와 Equinox 로 같은 작은 CNN (3 conv + 1 dense) 을 작성해. 작은 합성 dataset 으로 1 epoch 학습시켜. 코드 줄 수, mental overhead, final loss 를 비교해. 우열을 가리지 말고 — 각자의 ergonomics 특징을 적어 봐.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.