실전 NN 에서 자주 쓰는 layer — 대부분 NNX / Equinox 가 제공. 외워둘 만한 것:
Linear / Dense
# Flax NNX — batched by default: the layer acts on the last axis of its input
linear = nnx.Linear(784, 128, rngs=nnx.Rngs(0))
# Equinox — one unbatched sample per call (wrap in jax.vmap for a batch);
# `k` is a jax.random key assumed to be created earlier
linear = eqx.nn.Linear(in_features=784, out_features=128, key=k)
Conv2d
# NNX — channels-last input (batch, H, W, C); "SAME" keeps H, W at stride 1
conv = nnx.Conv(
    in_features=3,       # input channels
    out_features=64,     # output channels
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="SAME",
    rngs=nnx.Rngs(0),
)
# Equinox — unbatched channels-first input (C, H, W); padding=1 around a
# 3x3 kernel at stride 1 also preserves H and W (H + 2 - 3 + 1 = H)
conv = eqx.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, key=k)
LayerNorm
# NNX — normalizes over the last axis, whose size is 128
ln = nnx.LayerNorm(128, rngs=nnx.Rngs(0))
# Equinox — `shape` is the full (unbatched) input shape to normalize over
ln = eqx.nn.LayerNorm(shape=(128,))
Dropout
# NNX — dropout is switched on/off by a `training` flag
class MyModel(nnx.Module):
    """Linear(128 -> 64), then dropout (rate 0.5), then ReLU."""

    def __init__(self, *, rngs):
        # Dropout is handed `rngs` here, so __call__ needs no key argument.
        self.dense = nnx.Linear(128, 64, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

    def __call__(self, x, *, training=True):
        # deterministic=True (i.e. training=False) turns dropout off.
        h = self.dropout(self.dense(x), deterministic=not training)
        return jax.nn.relu(h)
# Equinox — randomness comes from an explicit `key` argument
class MyModel(eqx.Module):
    """Linear(128 -> 64), then dropout (p=0.5), then ReLU, on one unbatched sample."""

    dense: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, key):
        # Only the Linear layer has parameters; Dropout needs no init key,
        # so the key is used directly instead of being split.
        self.dense = eqx.nn.Linear(128, 64, key=key)
        self.dropout = eqx.nn.Dropout(p=0.5)

    def __call__(self, x, key=None, *, inference=None):
        """Apply the model to a single sample.

        Args:
            x: one sample with 128 features (Equinox layers are unbatched;
                use jax.vmap for a batch).
            key: PRNG key for the dropout mask; required while dropout is active.
            inference: True disables dropout (the counterpart of NNX's
                `training=False`). None defers to the Dropout layer's own
                setting, so `eqx.nn.inference_mode(model)` keeps working.
        """
        x = self.dense(x)
        x = self.dropout(x, key=key, inference=inference)
        return jax.nn.relu(x)
Multi-Head Attention
# NNX — `decode` must be set (here or per call); False = no autoregressive KV cache
attn = nnx.MultiHeadAttention(
    num_heads=8,
    in_features=512,
    decode=False,
    rngs=nnx.Rngs(0),
)
# The call parameters are inputs_q / inputs_k / inputs_v; there are no
# query= / key= / value= keywords. Passing x three times gives self-attention.
out = attn(x, x, x)  # x: (..., L, 512) — batched input is fine in NNX
# Equinox — unbatched: query/key/value are (seq_len, 512); vmap over a batch
attn = eqx.nn.MultiheadAttention(
    query_size=512,
    num_heads=8,
    key=k,
)
out = attn(x, x, x)  # (query, key, value) all = x -> self-attention
Embedding
# NNX — looks up a whole batch of integer ids in one call
emb = nnx.Embed(num_embeddings=10000, features=512, rngs=nnx.Rngs(0))
y = emb(token_ids)  # token_ids: (B, L) int ids -> y: (B, L, 512)
# Equinox — looks up ONE scalar index per call; map over a sequence with
# jax.vmap, e.g. jax.vmap(emb)(ids) for ids of shape (L,) gives (L, 512)
emb = eqx.nn.Embedding(embedding_size=512, num_embeddings=10000, key=k)
Activation 은 함수
# Element-wise activations: plain functions, no parameters
jax.nn.relu(x)
jax.nn.gelu(x)
jax.nn.silu(x)  # a.k.a. swish; the gating nonlinearity used inside SwiGLU
jax.nn.tanh(x)
jax.nn.sigmoid(x)
# Normalizing activations: these act along an axis, so pass it explicitly
jax.nn.softmax(x, axis=-1)
jax.nn.log_softmax(x, axis=-1)
activation 은 — class 가 아니고 그냥 함수. 학습할 파라미터가 없으니 module 로 감쌀 필요 없이 `jax.nn.*` 를 바로 호출한다. JAX 의 functional 정신.
Residual Connection
class ResBlock(nnx.Module):
    """Pre-LayerNorm residual MLP block: x + dense2(gelu(dense1(ln(x))))."""

    def __init__(self, dim, *, rngs):
        hidden = 4 * dim  # MLP hidden width: 4x the model width
        self.ln = nnx.LayerNorm(dim, rngs=rngs)
        self.dense1 = nnx.Linear(dim, hidden, rngs=rngs)
        self.dense2 = nnx.Linear(hidden, dim, rngs=rngs)

    def __call__(self, x):
        h = self.dense1(self.ln(x))
        h = jax.nn.gelu(h)
        # Skip connection: the untouched input is added back to the branch output.
        return x + self.dense2(h)
💡 layer 의 batched-vs-single 차이
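Flax NNX 의 Linear 는 — batched input 을 그대로 받는다 (마지막 axis 에만 적용되고, 앞쪽 axis 들은 batch 차원으로 유지). Equinox 의 Linear 는 — 단일 sample input 만 받는다 (Conv2d, LayerNorm, MultiheadAttention, Embedding 도 마찬가지). batch 처리는 `jax.vmap` 으로. 처음 헷갈리면 docs 한 번 보고 통일.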
이 quest 의 다음 두 lesson 에서 — CNN (10-5) 과 Transformer block (10-6) 을 합성. 이 layer 들이 building block.