C.W.K.
Stream
Lesson 03 of 06 · published

Equinox — 모델은 곧 Pytree

~10 min · neural-nets, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

Equinox 는 다른 철학 — model 은 pytree. nn.Module 같은 거 없이, 그냥 dataclass 가 곧 model.

pip install equinox
import equinox as eqx
import jax
import jax.numpy as jnp

# ============ 단일 Linear ============
class Linear(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, in_dim, out_dim, key):
        wkey, bkey = jax.random.split(key, 2)
        self.weight = jax.random.normal(wkey, (out_dim, in_dim)) * 0.01
        self.bias = jnp.zeros(out_dim)

    def __call__(self, x):
        return self.weight @ x + self.bias

key = jax.random.PRNGKey(0)
layer = Linear(10, 5, key)

x = jnp.zeros(10)
y = layer(x)
print(y.shape)   # (5,)

eqx.Module = @dataclass + 자동 pytree 등록. 그게 다.

# model 자체가 pytree
print(jax.tree.leaves(layer))
# [array(weight), array(bias)]

# tree.map 자유롭게
zeroed = jax.tree.map(jnp.zeros_like, layer)

다층 MLP

class MLP(eqx.Module):
    layers: list   # List[Linear]

    def __init__(self, dims, key):
        keys = jax.random.split(key, len(dims) - 1)
        self.layers = [Linear(d_in, d_out, k)
                       for d_in, d_out, k in zip(dims[:-1], dims[1:], keys)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)

key = jax.random.PRNGKey(0)
model = MLP([784, 128, 64, 10], key)
y = model(jnp.zeros(784))

학습

Equinox model 은 pytree 라 — 모든 JAX 변환이 그대로 작동:

def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)   # batch
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(model, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(model, x, y)
    new_model = jax.tree.map(lambda p, g: p - lr * g, model, grads)
    return new_model, loss

# loop
for step in range(100):
    model, loss = train_step(model, batch_x, batch_y, 0.01)

특이 점이 없음 — nnx.split 같은 ceremony 없이, model 이 그냥 pytree 라 jit/grad 가 직접 처리.

filter / partition — trainable 과 frozen 분리

model = MLP([784, 10], key)

# 모든 param 이 변경 가능 (default)
# 일부만 학습 — eqx.filter 사용
def loss_with_frozen(diff_model, static_model, x, y):
    model = eqx.combine(diff_model, static_model)
    return loss_fn(model, x, y)

# layer 0 은 freeze, layer 1 만 학습
diff_model, static_model = eqx.partition(model,
    lambda m: True if isinstance(m, Linear) and m is model.layers[1] else False
)

grads = jax.grad(loss_with_frozen)(diff_model, static_model, x, y)

built-in layer

model = eqx.nn.Sequential([
    eqx.nn.Linear(784, 128, key=k1),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(128, 10, key=k2),
])

# attention block
attn = eqx.nn.MultiheadAttention(
    num_heads=8, query_size=64, key=k3,
)

🌿 Equinox 의 정신

"model 도 그냥 데이터" 라는 JAX 의 철학을 가장 충실히. nn.Module 같은 magic 한 prototyping 없이 — eqx.Module 은 그냥 dataclass + pytree 등록. 결과: 모든 JAX 변환이 부담 없이 호환. 단점은 — PyTorch 의 self.x = ... mutation 패턴이 안 됨 (그게 의도). 학습은 — 새 model 객체를 매 step 만들어서 갱신.

NNX 와 Equinox 의 선택 — 팀 / 프로젝트 / 코드 스타일 취향. 둘 다 production-ready. JAX core 가 같으니 — 한 쪽 익히면 다른 쪽도 빠르게 따라잡음.

Code

import equinox as eqx
import jax
import jax.numpy as jnp

class MLP(eqx.Module):
    layers: list
    dropout: eqx.nn.Dropout

    def __init__(self, in_dim, hidden_dim, out_dim, *, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(in_dim, hidden_dim, key=k1),
            eqx.nn.Linear(hidden_dim, out_dim, key=k2),
        ]
        self.dropout = eqx.nn.Dropout(p=0.2)

    def __call__(self, x, key=None):
        x = self.layers[0](x)
        x = jax.nn.relu(x)
        x = self.dropout(x, key=key)
        x = self.layers[1](x)
        return x

# Create model — key for parameter initialization
model = MLP(784, 256, 10, key=jax.random.key(0))

# Call it directly
x = jnp.ones((784,))
y = model(x, key=jax.random.key(1))
print(y.shape)  # (10,)
@eqx.filter_jit
@eqx.filter_grad
def compute_loss(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

# filter_grad automatically differentiates only w.r.t. arrays,
# leaving static fields (dropout rate, etc.) untouched
grads = compute_loss(model, x_batch, y_batch)

# Update with optax
import optax
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Apply updates
updates, opt_state = optimizer.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
# Split model into trainable and static parts
params, static = eqx.partition(model, eqx.is_array)
# params: same tree structure, but non-arrays replaced with None
# static: same tree structure, but arrays replaced with None

# Recombine
model = eqx.combine(params, static)

# Freeze specific layers
def freeze_first_layer(model):
    # Use tree_at to target specific parts
    filter_spec = jax.tree.map(lambda _: True, model)
    filter_spec = eqx.tree_at(
        lambda m: m.layers[0],
        filter_spec,
        replace=jax.tree.map(lambda _: False, model.layers[0])
    )
    return filter_spec

External links

Exercise

10-2 와 같은 task — Equinox 로: Linear 만들기, 1 step 학습. API 비교 — Equinox 의 'model 이 그냥 pytree' vs Flax 의 mutable state 모델. 새 research 시작할 거 선택 + 이유 적기.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.