Equinox 는 다른 철학 — model 은 pytree. nn.Module 같은 거 없이, 그냥 dataclass 가 곧 model.
pip install equinox
import equinox as eqx
import jax
import jax.numpy as jnp
# ============ 단일 Linear ============
class Linear(eqx.Module):
weight: jnp.ndarray
bias: jnp.ndarray
def __init__(self, in_dim, out_dim, key):
wkey, bkey = jax.random.split(key, 2)
self.weight = jax.random.normal(wkey, (out_dim, in_dim)) * 0.01
self.bias = jnp.zeros(out_dim)
def __call__(self, x):
return self.weight @ x + self.bias
key = jax.random.PRNGKey(0)
layer = Linear(10, 5, key)
x = jnp.zeros(10)
y = layer(x)
print(y.shape) # (5,)
eqx.Module = @dataclass + 자동 pytree 등록. 그게 다.
# model 자체가 pytree
print(jax.tree.leaves(layer))
# [array(weight), array(bias)]
# tree.map 자유롭게
zeroed = jax.tree.map(jnp.zeros_like, layer)
다층 MLP
class MLP(eqx.Module):
layers: list # List[Linear]
def __init__(self, dims, key):
keys = jax.random.split(key, len(dims) - 1)
self.layers = [Linear(d_in, d_out, k)
for d_in, d_out, k in zip(dims[:-1], dims[1:], keys)]
def __call__(self, x):
for layer in self.layers[:-1]:
x = jax.nn.relu(layer(x))
return self.layers[-1](x)
key = jax.random.PRNGKey(0)
model = MLP([784, 128, 64, 10], key)
y = model(jnp.zeros(784))
학습
Equinox model 은 pytree 라 — 모든 JAX 변환이 그대로 작동:
def loss_fn(model, x, y):
pred = jax.vmap(model)(x) # batch
return jnp.mean((pred - y) ** 2)
@jax.jit
def train_step(model, x, y, lr):
loss, grads = jax.value_and_grad(loss_fn)(model, x, y)
new_model = jax.tree.map(lambda p, g: p - lr * g, model, grads)
return new_model, loss
# loop
for step in range(100):
model, loss = train_step(model, batch_x, batch_y, 0.01)
특이 점이 없음 — nnx.split 같은 ceremony 없이, model 이 그냥 pytree 라 jit/grad 가 직접 처리.
filter / partition — trainable 과 frozen 분리
model = MLP([784, 10], key)
# 모든 param 이 변경 가능 (default)
# 일부만 학습 — eqx.filter 사용
def loss_with_frozen(diff_model, static_model, x, y):
model = eqx.combine(diff_model, static_model)
return loss_fn(model, x, y)
# layer 0 은 freeze, layer 1 만 학습
diff_model, static_model = eqx.partition(model,
lambda m: True if isinstance(m, Linear) and m is model.layers[1] else False
)
grads = jax.grad(loss_with_frozen)(diff_model, static_model, x, y)
built-in layer
model = eqx.nn.Sequential([
eqx.nn.Linear(784, 128, key=k1),
eqx.nn.Lambda(jax.nn.relu),
eqx.nn.Linear(128, 10, key=k2),
])
# attention block
attn = eqx.nn.MultiheadAttention(
num_heads=8, query_size=64, key=k3,
)
🌿 Equinox 의 정신
"model 도 그냥 데이터" 라는 JAX 의 철학을 가장 충실히. nn.Module 같은 magic 한 prototyping 없이 — eqx.Module 은 그냥 dataclass + pytree 등록. 결과: 모든 JAX 변환이 부담 없이 호환. 단점은 — PyTorch 의 self.x = ... mutation 패턴이 안 됨 (그게 의도). 학습은 — 새 model 객체를 매 step 만들어서 갱신.
NNX 와 Equinox 의 선택 — 팀 / 프로젝트 / 코드 스타일 취향. 둘 다 production-ready. JAX core 가 같으니 — 한 쪽 익히면 다른 쪽도 빠르게 따라잡음.