C.W.K.
Stream
Lesson 02 of 06 · published

Flax NNX — 새 표준

~9 min · neural-nets, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

Flax NNX 는 — Flax 의 새 API. 2024 부터 권장. mutable Python state 로 PyTorch 처럼 친근하지만 — JAX transform 과 잘 호환.

pip install flax
from flax import nnx
import jax
import jax.numpy as jnp

# ============ 단일 Linear layer ============
class Linear(nnx.Module):
    def __init__(self, in_dim, out_dim, *, rngs):
        self.W = nnx.Param(jax.random.normal(rngs.params(), (in_dim, out_dim)) * 0.01)
        self.b = nnx.Param(jnp.zeros(out_dim))

    def __call__(self, x):
        return x @ self.W + self.b

# 사용
rngs = nnx.Rngs(0)   # PRNG state
layer = Linear(10, 5, rngs=rngs)

x = jnp.zeros(10)
y = layer(x)
print(y.shape)   # (5,)

# parameter 접근
print(layer.W.value.shape)   # (10, 5)

다층 model

class MLP(nnx.Module):
    def __init__(self, dims, *, rngs):
        self.layers = [
            Linear(dims[i], dims[i+1], rngs=rngs)
            for i in range(len(dims) - 1)
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)

model = MLP([784, 128, 64, 10], rngs=nnx.Rngs(0))
y = model(jnp.zeros(784))

PyTorch 처럼 보이지? 차이는 — JAX transform 이 잘 호환:

@nnx.jit
def train_step(model, x, y):
    def loss_fn(model):
        logits = model(x)
        return jnp.mean((logits - y) ** 2)

    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model)

    nnx.update(model, jax.tree.map(lambda p, g: p - 0.01 * g,
                                    nnx.state(model, nnx.Param),
                                    grads))
    return loss

state 분리

NNX 의 핵심 — model 의 mutable state 와 정적 구조를 분리. nnx.state, nnx.update, nnx.split, nnx.merge:

model = MLP([784, 128, 10], rngs=nnx.Rngs(0))

# state 추출 — 모든 trainable param
state = nnx.state(model, nnx.Param)
# state 는 pytree

# graphdef — model 의 구조 (변하지 않는 부분)
graphdef, state = nnx.split(model)

# 다시 합치기
restored = nnx.merge(graphdef, state)

# pure transform — graphdef 를 static, state 를 dynamic 으로
@jax.jit
def pure_step(graphdef, state, x):
    model = nnx.merge(graphdef, state)
    return model(x)

BatchNorm 같은 stateful layer

class MyBN(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.gamma = nnx.Param(jnp.ones(dim))
        self.beta = nnx.Param(jnp.zeros(dim))
        self.running_mean = nnx.Variable(jnp.zeros(dim))   # not trainable
        self.running_var = nnx.Variable(jnp.ones(dim))

    def __call__(self, x, training=True):
        if training:
            mean = x.mean(0)
            var = x.var(0)
            self.running_mean.value = 0.9 * self.running_mean.value + 0.1 * mean
            self.running_var.value = 0.9 * self.running_var.value + 0.1 * var
        else:
            mean = self.running_mean.value
            var = self.running_var.value
        x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
        return self.gamma.value * x_norm + self.beta.value

nnx.Param = trainable, nnx.Variable = non-trainable mutable. nnx.state 가 Param 만 grad 흘림.

💡 PyTorch 사용자에게

NNX 는 PyTorch 와 거의 같은 ergonomics. 한 가지 큰 차이 — model 을 함수형으로 다룰 수도 있음 (nnx.split / merge). 학습 루프는 PyTorch 처럼 짜고, 깊은 transform 이 필요하면 functional 로 내려가는 식. PyTorch 에서 JAX 로 옮기는 가장 부드러운 길.

활용 — Flax 표준 layer (nnx.Linear, nnx.Conv, nnx.LayerNorm 등) 가 풍부. 직접 안 만들고 — 라이브러리 layer 합쳐서 model 구성. Track 10-4 에서 다룸.

Code

from flax import nnx
import jax.numpy as jnp

# Define a simple model — looks like PyTorch!
class MLP(nnx.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.bn = nnx.BatchNorm(hidden_dim, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.2, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = self.bn(x)
        x = nnx.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# Create the model — eager initialization, no lazy shapes
model = MLP(784, 256, 10, rngs=nnx.Rngs(0))

# Call it directly — no .apply() needed!
x = jnp.ones((32, 784))
y = model(x)
print(y.shape)  # (32, 10)
# NNX transforms handle modules with state
import optax

model = MLP(784, 256, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

# Training step — model and optimizer are mutated in-place
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits = model(x)
        return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
            logits, y))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss

# Use it — no return/reassign dance
# loss = train_step(model, optimizer, x_batch, y_batch)

External links

Exercise

flax 설치. nnx.Linear 만들기. random input 에 forward. parameter pytree 출력. nnx.jit 로 wrap. track 9 의 손수 만든 JAX MLP 와 비교 — 변한 것, 그대로인 것.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.