Flax NNX 는 — Flax 의 새 API. 2024 부터 권장. mutable Python state 로 PyTorch 처럼 친근하지만 — JAX transform 과 잘 호환.
pip install flax
from flax import nnx
import jax
import jax.numpy as jnp
# ============ 단일 Linear layer ============
class Linear(nnx.Module):
def __init__(self, in_dim, out_dim, *, rngs):
self.W = nnx.Param(jax.random.normal(rngs.params(), (in_dim, out_dim)) * 0.01)
self.b = nnx.Param(jnp.zeros(out_dim))
def __call__(self, x):
return x @ self.W + self.b
# 사용
rngs = nnx.Rngs(0) # PRNG state
layer = Linear(10, 5, rngs=rngs)
x = jnp.zeros(10)
y = layer(x)
print(y.shape) # (5,)
# parameter 접근
print(layer.W.value.shape) # (10, 5)
다층 model
class MLP(nnx.Module):
def __init__(self, dims, *, rngs):
self.layers = [
Linear(dims[i], dims[i+1], rngs=rngs)
for i in range(len(dims) - 1)
]
def __call__(self, x):
for layer in self.layers[:-1]:
x = jax.nn.relu(layer(x))
return self.layers[-1](x)
model = MLP([784, 128, 64, 10], rngs=nnx.Rngs(0))
y = model(jnp.zeros(784))
PyTorch 처럼 보이지? 차이는 — JAX transform 이 잘 호환:
@nnx.jit
def train_step(model, x, y):
def loss_fn(model):
logits = model(x)
return jnp.mean((logits - y) ** 2)
grad_fn = nnx.value_and_grad(loss_fn)
loss, grads = grad_fn(model)
nnx.update(model, jax.tree.map(lambda p, g: p - 0.01 * g,
nnx.state(model, nnx.Param),
grads))
return loss
state 분리
NNX 의 핵심 — model 의 mutable state 와 정적 구조를 분리. nnx.state, nnx.update, nnx.split, nnx.merge:
model = MLP([784, 128, 10], rngs=nnx.Rngs(0))
# state 추출 — 모든 trainable param
state = nnx.state(model, nnx.Param)
# state 는 pytree
# graphdef — model 의 구조 (변하지 않는 부분)
graphdef, state = nnx.split(model)
# 다시 합치기
restored = nnx.merge(graphdef, state)
# pure transform — graphdef 를 static, state 를 dynamic 으로
@jax.jit
def pure_step(graphdef, state, x):
model = nnx.merge(graphdef, state)
return model(x)
BatchNorm 같은 stateful layer
class MyBN(nnx.Module):
def __init__(self, dim, *, rngs):
self.gamma = nnx.Param(jnp.ones(dim))
self.beta = nnx.Param(jnp.zeros(dim))
self.running_mean = nnx.Variable(jnp.zeros(dim)) # not trainable
self.running_var = nnx.Variable(jnp.ones(dim))
def __call__(self, x, training=True):
if training:
mean = x.mean(0)
var = x.var(0)
self.running_mean.value = 0.9 * self.running_mean.value + 0.1 * mean
self.running_var.value = 0.9 * self.running_var.value + 0.1 * var
else:
mean = self.running_mean.value
var = self.running_var.value
x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
return self.gamma.value * x_norm + self.beta.value
nnx.Param = trainable, nnx.Variable = non-trainable mutable. nnx.state 가 Param 만 grad 흘림.
💡 PyTorch 사용자에게
NNX 는 PyTorch 와 거의 같은 ergonomics. 한 가지 큰 차이 — model 을 함수형으로 다룰 수도 있음 (nnx.split / merge). 학습 루프는 PyTorch 처럼 짜고, 깊은 transform 이 필요하면 functional 로 내려가는 식. PyTorch 에서 JAX 로 옮기는 가장 부드러운 길.
활용 — Flax 표준 layer (nnx.Linear, nnx.Conv, nnx.LayerNorm 등) 가 풍부. 직접 안 만들고 — 라이브러리 layer 합쳐서 model 구성. Track 10-4 에서 다룸.