C.W.K.
Lesson 04 of 06 · published

Common Layers and Building Blocks

~9 min · neural-nets, jax, tutorial

These are the layers you reach for most often in real-world neural networks; Flax NNX and Equinox provide most of them. The ones worth memorizing:

Linear / Dense

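import jax
import equinox as eqx
from flax import nnx
k = jax.random.key(0)  # PRNG key reused by the Equinox snippets in this lesson

# Flax NNX
linear = nnx.Linear(in_features=784, out_features=128, rngs=nnx.Rngs(0))

# Equinox
linear = eqx.nn.Linear(784, 128, key=k)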

Conv2d

# NNX: expects channels-last input, e.g. (B, H, W, 3)
conv = nnx.Conv(
    in_features=3,
    out_features=64,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="SAME",
    rngs=nnx.Rngs(0),
)

# Equinox: takes a single channels-first input (3, H, W); use jax.vmap for a batch
conv = eqx.nn.Conv2d(3, 64, kernel_size=3, padding=1, key=k)

LayerNorm

# NNX
ln = nnx.LayerNorm(num_features=128, rngs=nnx.Rngs(0))

# Equinox
ln = eqx.nn.LayerNorm((128,))

Dropout

# NNX: training flag
class MyModel(nnx.Module):
    def __init__(self, *, rngs):
        self.dense = nnx.Linear(128, 64, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

    def __call__(self, x, *, training=True):
        x = self.dense(x)
        x = self.dropout(x, deterministic=not training)
        return jax.nn.relu(x)

# Equinox: randomness comes from an explicit key
class MyModel(eqx.Module):
    dense: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, key):
        self.dense = eqx.nn.Linear(128, 64, key=key)
        self.dropout = eqx.nn.Dropout(p=0.5)   # no parameters, so it needs no key here

    def __call__(self, x, key):
        x = self.dense(x)
        x = self.dropout(x, key=key)
        return jax.nn.relu(x)
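
To make the training flag and the explicit key concrete, here is a minimal sketch of what a dropout layer computes, written in plain JAX. The `dropout` function below is a hand-written stand-in for illustration, not the implementation used by `nnx.Dropout` or `eqx.nn.Dropout`; it assumes the common "inverted dropout" formulation, where kept values are rescaled by `1 / (1 - rate)`.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, key=None, deterministic=False):
    """Inverted dropout (illustrative stand-in, not a library function)."""
    if deterministic or rate == 0.0:
        return x  # evaluation mode: the layer is the identity
    keep_prob = 1.0 - rate
    # Each element is kept with probability keep_prob; the key makes this reproducible.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Rescale kept values so the expected activation is unchanged.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((4, 8))
train_out = dropout(x, rate=0.5, key=jax.random.key(0))   # entries are 0.0 or 2.0
eval_out = dropout(x, rate=0.5, deterministic=True)       # identical to x
```

The two library styles above differ only in where the randomness comes from: NNX draws it from the `rngs` stored on the module, while Equinox asks the caller to pass a fresh `key` on every call.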

Multi-Head Attention

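# NNX
attn = nnx.MultiHeadAttention(
    num_heads=8,
    in_features=512,
    decode=False,   # set explicitly; leaving it unset makes the call below raise
    rngs=nnx.Rngs(0),
)
out = attn(x)   # self-attention: key and value inputs default to the query input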

# Equinox
attn = eqx.nn.MultiheadAttention(
    num_heads=8,
    query_size=512,
    key=k,
)
out = attn(x, x, x)   # self-attention; x: (seq_len, 512), no batch axis

Embedding

# NNX
emb = nnx.Embed(num_embeddings=10000, features=512, rngs=nnx.Rngs(0))
y = emb(token_ids)   # token_ids: (B, L) → y: (B, L, 512)

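# Equinox
emb = eqx.nn.Embedding(num_embeddings=10000, embedding_size=512, key=k)
y = jax.vmap(emb)(token_ids)   # token_ids: (L,) → y: (L, 512); one index per call, so vmap over the sequence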
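
Conceptually, an embedding layer is a row lookup into a `(num_embeddings, features)` table. The sketch below reproduces the shape behavior in plain JAX; the names `table` and `token_ids` and the random initialization are assumptions for illustration, not how either library stores or initializes its weights.

```python
import jax
import jax.numpy as jnp

vocab_size, features = 10000, 512
# Stand-in for the learned embedding matrix: one row per vocabulary entry.
table = jax.random.normal(jax.random.key(0), (vocab_size, features))

token_ids = jnp.array([[1, 5, 42], [7, 0, 9999]])   # (B, L) = (2, 3)

# Integer-array indexing gathers one row per token id.
y = table[token_ids]                                 # (B, L, features)
```

Every token id must be smaller than `vocab_size`; the lookup itself has no other moving parts.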

Activations are functions

jax.nn.relu(x)
jax.nn.gelu(x)
jax.nn.silu(x)   # used inside SwiGLU
jax.nn.softmax(x, axis=-1)
jax.nn.log_softmax(x, axis=-1)
jax.nn.tanh(x)
jax.nn.sigmoid(x)

Activations are plain functions, not classes, in keeping with JAX's functional style.
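
A short sketch of what "just functions" buys you: there is no layer object to construct, and the functions compose directly with JAX transformations such as `jax.grad` and `jax.vmap`. The helper `swiglu_gate` is a hypothetical name introduced here to illustrate the gating pattern; it is not a JAX API.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 2.0])

# Plain function calls: nothing to instantiate, no parameters to initialize.
relu_out = jax.nn.relu(x)
probs = jax.nn.softmax(x, axis=-1)

# Functions compose with transformations: elementwise derivative of relu.
relu_slope = jax.vmap(jax.grad(jax.nn.relu))(x)


def swiglu_gate(a, b):
    """Gating pattern used in SwiGLU-style feed-forward blocks: silu(a) * b."""
    return jax.nn.silu(a) * b


gated = swiglu_gate(x, jnp.ones_like(x))
```

In a module, you call these inside `__call__` between layers, exactly as the Dropout examples above do with `jax.nn.relu`.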

Residual Connection

class ResBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.ln = nnx.LayerNorm(dim, rngs=rngs)
        self.dense1 = nnx.Linear(dim, 4*dim, rngs=rngs)
        self.dense2 = nnx.Linear(4*dim, dim, rngs=rngs)

    def __call__(self, x):
        residual = x
        x = self.ln(x)
        x = self.dense1(x)
        x = jax.nn.gelu(x)
        x = self.dense2(x)
        return residual + x   # skip connection

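💡 Batched vs. single-example layers

Flax NNX's Linear handles batched input automatically: it applies the layer over the last axis and treats every leading axis as a batch dimension. Equinox's Linear takes a single example, and you batch it with vmap. If this trips you up at first, check the docs once and stick to one convention.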
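
The two conventions can be compared in plain JAX. In this sketch, `linear_single` is a hand-written stand-in for a single-example layer (the Equinox convention), and the direct matmul stands in for a layer that accepts a batch (the NNX convention); the weights `W` and `b` are arbitrary values chosen for illustration.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.key(0))
W = jax.random.normal(key_w, (128, 784))   # (out_features, in_features)
b = jnp.zeros((128,))


def linear_single(x):
    """Single-example convention: x is (784,), the result is (128,)."""
    return W @ x + b


batch = jax.random.normal(key_x, (32, 784))   # 32 examples

# vmap maps the single-example function over the leading batch axis.
batched_out = jax.vmap(linear_single)(batch)  # shape (32, 128)

# Batched convention: one matmul over the last axis gives the same values.
direct_out = batch @ W.T + b                  # shape (32, 128)
```

With a real `eqx.nn.Linear` the same pattern reads `jax.vmap(linear)(batch)`, whereas an `nnx.Linear` can be called on `batch` directly.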

In the next two lessons of this quest, these layers become the building blocks we compose into a CNN (10-5) and a Transformer block (10-6).

Code

# Flax NNX
from flax import nnx

rngs = nnx.Rngs(0)

# Dense (Linear) layer
dense = nnx.Linear(in_features=256, out_features=128, rngs=rngs)

# Convolution
conv = nnx.Conv(
    in_features=3,        # input channels
    out_features=64,      # output channels
    kernel_size=(3, 3),   # filter size
    strides=(1, 1),
    padding='SAME',
    rngs=rngs,
)

# Embedding
embed = nnx.Embed(
    num_embeddings=10000,  # vocabulary size
    features=256,          # embedding dimension
    rngs=rngs,
)

# Normalization
layer_norm = nnx.LayerNorm(num_features=256, rngs=rngs)
batch_norm = nnx.BatchNorm(num_features=256, rngs=rngs)

# Dropout
dropout = nnx.Dropout(rate=0.1, rngs=rngs)

# Multi-head attention
attention = nnx.MultiHeadAttention(
    num_heads=8,
    in_features=256,
    qkv_features=256,
    out_features=256,
    decode=False,   # set explicitly; leaving it unset makes a later call raise
    rngs=rngs,
)

# Equinox
import equinox as eqx
import jax

key = jax.random.key(0)
keys = jax.random.split(key, 6)

# Dense (Linear) layer
dense = eqx.nn.Linear(in_features=256, out_features=128, key=keys[0])

# Convolution
conv = eqx.nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    key=keys[1],
)

# Embedding
embed = eqx.nn.Embedding(
    num_embeddings=10000,
    embedding_size=256,
    key=keys[2],
)

# Normalization
layer_norm = eqx.nn.LayerNorm(shape=(256,))
# eqx.nn.BatchNorm exists but is stateful (it needs an axis_name and eqx.nn.State), so it is omitted here

# Dropout
dropout = eqx.nn.Dropout(p=0.1)

# Multi-head attention
attention = eqx.nn.MultiheadAttention(
    num_heads=8,
    query_size=256,
    key=keys[3],
)

# Rough equivalents (rngs=... / key=... arguments omitted for brevity)
# PyTorch                         → Flax NNX                       → Equinox
# nn.Linear(256, 128)             → nnx.Linear(256, 128)           → eqx.nn.Linear(256, 128)
# nn.Conv2d(3, 64, 3, padding=1)  → nnx.Conv(3, 64, (3, 3))        → eqx.nn.Conv2d(3, 64, 3, padding=1)
# nn.Embedding(10000, 256)        → nnx.Embed(10000, 256)          → eqx.nn.Embedding(10000, 256)
# nn.LayerNorm(256)               → nnx.LayerNorm(256)             → eqx.nn.LayerNorm((256,))
# nn.Dropout(0.1)                 → nnx.Dropout(rate=0.1)          → eqx.nn.Dropout(p=0.1)
# nn.MultiheadAttention(256, 8)   → nnx.MultiHeadAttention(8, 256) → eqx.nn.MultiheadAttention(8, 256)

Exercise

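In the library of your choice (Flax NNX or Equinox), instantiate Linear, LayerNorm, Dropout, and MultiHeadAttention. Run a forward pass through each on a toy input and print the parameter shapes. The goal is to get comfortable with these layers, since they are the building blocks of a transformer.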
