C.W.K.
Stream
Lesson 06 of 06 · published

Transformer Block 만들기

~12 min · neural-nets, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

현대 ML 의 unit cell — Transformer block. multi-head attention + MLP + residual + layer norm. JAX (Flax NNX) 로 만들기.

import jax
import jax.numpy as jnp
from flax import nnx

class TransformerBlock(nnx.Module):
    """Pre-norm Transformer block: LN -> self-attention -> residual, LN -> MLP -> residual.

    Args:
        dim: model width D of the input/output features.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        rngs: NNX RNG streams used to initialise every sub-layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, *, rngs):
        self.ln1 = nnx.LayerNorm(dim, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=dim,
            # Explicitly disable autoregressive KV-cache decoding; some NNX
            # versions raise at call time when `decode` is left unset.
            decode=False,
            rngs=rngs,
        )
        self.ln2 = nnx.LayerNorm(dim, rngs=rngs)
        # nnx.Sequential takes its layers as positional arguments, not a list.
        self.mlp = nnx.Sequential(
            nnx.Linear(dim, mlp_ratio * dim, rngs=rngs),
            jax.nn.gelu,
            nnx.Linear(mlp_ratio * dim, dim, rngs=rngs),
        )

    def __call__(self, x, mask=None):
        """Apply the block.

        Args:
            x: input activations, shaped (B, L, D).
            mask: optional boolean attention mask, forwarded unchanged to
                ``nnx.MultiHeadAttention`` (e.g. an (L, L) causal mask).

        Returns:
            Array with the same shape as ``x`` (both sub-layers are residual).
        """
        # Pre-norm style (standard in GPT, Llama, etc.).
        h = self.ln1(x)
        # Self-attention: NNX takes the query positionally (`inputs_q`);
        # key and value default to the query when they are omitted.
        h = self.attn(h, mask=mask)
        x = x + h           # residual

        h = self.ln2(x)
        h = self.mlp(h)
        x = x + h           # residual

        return x

# Usage: build one block and run a dummy forward pass.
block = TransformerBlock(dim=512, num_heads=8, rngs=nnx.Rngs(0))  # mlp_ratio left at its default (4)
x = jnp.zeros((4, 64, 512))   # batch=4, seq=64, dim=512
y = block(x)   # mask=None: no attention mask is passed
print(y.shape)   # (4, 64, 512) -- residual connections keep the input shape

self-attention 의 분해

위는 NNX 의 built-in attention. 직접 구현하면:

class SelfAttention(nnx.Module):
    """Hand-written multi-head self-attention (fused QKV projection + einsum).

    Args:
        dim: model width D; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
        rngs: NNX RNG streams used to initialise the two Linear layers.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, *, rngs):
        # Fail early with a clear message instead of a cryptic reshape error.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # One Linear produces Q, K and V side by side: (..., 3 * dim).
        self.qkv = nnx.Linear(dim, 3 * dim, rngs=rngs)
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: input of shape (B, L, D).
            mask: optional boolean array broadcastable to (B, H, L, L);
                True keeps a score, False masks it out.

        Returns:
            Array of shape (B, L, D).
        """
        B, L, D = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)   # (3, B, H, L, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # attention scores: Q K^T / sqrt(head_dim), per batch and head
        scale = self.head_dim ** -0.5
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        if mask is not None:
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9. In float16, -1e9 overflows to -inf, which turns fully
            # masked rows into NaN after softmax.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        attn = jax.nn.softmax(scores, axis=-1)

        # weighted sum of values, then merge heads back into D
        out = jnp.einsum("bhqk,bhkd->bhqd", attn, v)
        out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
        return self.proj(out)

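einsum 을 쓰면 — attention 의 수식이 코드에 그대로 옮겨짐. 전체 연산은 $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$ (여기서 $d_k$ = head_dim, 코드의 scale = head_dim ** -0.5). jnp.einsum("bhqd,bhkd->bhqk", q, k) 는 $Q K^\top$ 를 batch (b) 와 head (h) 마다 계산한 것이고, 두 번째 einsum "bhqk,bhkd->bhqd" 는 softmax 결과에 $V$ 를 곱하는 부분.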

causal mask

def causal_mask(seq_len):
    """Return a (seq_len, seq_len) boolean lower-triangular mask.

    Entry [i, j] is True exactly when j <= i, so query position i may only
    attend to key positions at or before it (autoregressive generation).
    """
    idx = jnp.arange(seq_len)
    return idx[:, None] >= idx[None, :]

# Size the causal mask from x itself (seq = 64 here) instead of repeating the
# literal, so the two cannot drift apart if x's sequence length changes.
mask = causal_mask(x.shape[1])
y = block(x, mask=mask)   # attention with the causal mask applied

positional encoding

# Learnable positional embedding (a trained table, not a fixed sinusoid)
class GPTModel(nnx.Module):
    """Decoder-only (GPT-style) language model built from TransformerBlocks.

    Args:
        vocab: vocabulary size (embedding rows and number of output logits).
        dim: model width.
        num_layers: number of stacked TransformerBlocks.
        num_heads: attention heads per block.
        max_len: longest supported sequence (rows of the positional table).
        rngs: NNX RNG streams used for initialisation.
    """

    def __init__(self, vocab, dim, num_layers, num_heads, max_len, *, rngs):
        self.tok_emb = nnx.Embed(vocab, dim, rngs=rngs)
        # (max_len, dim) table, initialised with small Gaussian noise.
        self.pos_emb = nnx.Param(
            jax.random.normal(rngs.params(), (max_len, dim)) * 0.01
        )
        self.blocks = [
            TransformerBlock(dim, num_heads, rngs=rngs)
            for _ in range(num_layers)
        ]
        self.ln_f = nnx.LayerNorm(dim, rngs=rngs)
        self.head = nnx.Linear(dim, vocab, rngs=rngs)

    def __call__(self, tokens):
        """Compute next-token logits.

        Args:
            tokens: integer token ids of shape (B, L).

        Returns:
            Logits of shape (B, L, vocab).

        Raises:
            ValueError: if L exceeds the ``max_len`` the model was built with.
        """
        L = tokens.shape[1]
        max_len = self.pos_emb.value.shape[0]
        # Without this check pos_emb[:L] would silently be shorter than L and
        # fail later with an opaque broadcasting error.
        if L > max_len:
            raise ValueError(
                f"sequence length {L} exceeds max_len {max_len}"
            )
        x = self.tok_emb(tokens) + self.pos_emb.value[:L]
        # Causal mask so each position only attends to itself and the past.
        mask = causal_mask(L)
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        return self.head(x)   # (B, L, vocab)

model = GPTModel(vocab=10000, dim=512, num_layers=6, num_heads=8, max_len=128, rngs=nnx.Rngs(0))

tokens = jnp.zeros((2, 64), dtype=jnp.int32)   # (batch=2, seq=64) token ids
logits = model(tokens)
print(logits.shape)   # (2, 64, 10000)

# Parameter count: only nnx.Param leaves (weights, biases, embeddings, pos_emb).
n_params = sum(p.size for p in jax.tree.leaves(nnx.state(model, nnx.Param)))
# ~29.2M total: 6 blocks ~18.9M, plus token embedding 5.12M and output head 5.13M.
print(f"{n_params:,} parameters")

🌟 LLM 의 unit cell

지금 만든 GPTModel 은 — 이미 GPT 스타일 (decoder-only, causal mask) 언어 모델의 기본 골격. dim 과 layer 수를 키우고 RMSNorm, RoPE, SwiGLU 같은 변경을 더하면 — Llama 계열과 비슷한 구조. 여기에 토큰화 + 학습 데이터 + 적절한 학습 설정 (loss, optimizer) 만 추가하면 — 작은 LM 을 학습할 수 있음. JAX 의 표현력 덕분에 — 학생이 1 주 안에 만들 수 있는 분량의 코드.

이 quest 의 마지막 GPT-flavored model 은 — 공개된 자료로 알려진 범위에서 — GPT-4, Claude, Gemini 같은 대형 모델도 공유하는 Transformer 기반 핵심 구조. 차이는 — 데이터 양, computing power, 학습 / 미세 조정 trick, 그리고 (MoE, 긴 context 처리 같은) 구조적 개선. 기본 building block 의 코드 패턴은 — 같음.

Code

from flax import nnx
import jax
import jax.numpy as jnp

class TransformerBlock(nnx.Module):
    """A single pre-norm Transformer encoder block with dropout.

    Layout: ``x + Dropout(Attn(LN1(x)))``, then
    ``x + Dropout(W2(Dropout(gelu(W1(LN2(x))))))``.

    Args:
        d_model: model width (input and output feature size).
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout_rate: rate of the ``nnx.Dropout`` layer, which is reused at
            all three dropout sites in ``__call__``.
        rngs: NNX RNG streams for parameter initialisation and dropout.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                 dropout_rate: float, rngs: nnx.Rngs):
        # Multi-head self-attention. decode=False explicitly disables
        # autoregressive KV-cache decoding; some NNX versions raise at call
        # time when `decode` is left unset.
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=d_model,
            qkv_features=d_model,
            out_features=d_model,
            decode=False,
            rngs=rngs,
        )
        # Feed-forward network
        self.ff_linear1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff_linear2 = nnx.Linear(d_ff, d_model, rngs=rngs)

        # Layer normalization (pre-norm architecture)
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)

        # Dropout: one instance, shared by every dropout site below.
        # NOTE(review): nnx.Dropout is stochastic unless it is switched to
        # deterministic mode (e.g. model.eval()); confirm this for inference.
        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)

    def __call__(self, x, mask=None):
        """Apply the block.

        Args:
            x: input activations, presumably (batch, seq, d_model).
            mask: optional attention mask forwarded to the attention layer.

        Returns:
            Array with the same shape as ``x``.
        """
        # Pre-norm self-attention with residual
        residual = x
        x = self.ln1(x)
        x = self.attention(x, mask=mask)
        x = self.dropout(x)
        x = residual + x

        # Pre-norm feed-forward with residual
        residual = x
        x = self.ln2(x)
        x = self.ff_linear1(x)
        x = nnx.gelu(x)
        x = self.dropout(x)
        x = self.ff_linear2(x)
        x = self.dropout(x)
        x = residual + x

        return x

class SmallTransformer(nnx.Module):
    """A small Transformer encoder classifier with learned positional embeddings.

    Token and position embeddings are summed, passed through ``num_layers``
    TransformerBlocks, layer-normalised, and the vector at position 0 is fed
    to a linear classifier.
    """
    def __init__(self, vocab_size: int, d_model: int, num_heads: int,
                 d_ff: int, num_layers: int, max_len: int,
                 num_classes: int, rngs: nnx.Rngs):
        self.token_embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        # One learned vector per position 0 .. max_len - 1.
        self.pos_embed = nnx.Embed(max_len, d_model, rngs=rngs)
        self.blocks = [
            TransformerBlock(d_model, num_heads, d_ff, 0.1, rngs=rngs)
            for _ in range(num_layers)
        ]
        self.ln_final = nnx.LayerNorm(d_model, rngs=rngs)
        self.classifier = nnx.Linear(d_model, num_classes, rngs=rngs)

    def __call__(self, token_ids):
        """Classify each sequence in the batch.

        Args:
            token_ids: integer ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` used at init.
        """
        # Tuple unpacking also enforces a rank-2 input.
        _, seq_len = token_ids.shape
        max_len = self.pos_embed.num_embeddings
        # Positions >= max_len would gather past the end of the position
        # table, and JAX does not raise on out-of-bounds gathers.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        positions = jnp.arange(seq_len)[None, :]  # (1, seq_len), broadcasts over batch

        x = self.token_embed(token_ids) + self.pos_embed(positions)

        # No mask is passed, so attention is not causally restricted.
        for block in self.blocks:
            x = block(x)

        x = self.ln_final(x)
        x = x[:, 0, :]  # CLS token (first position)
        return self.classifier(x)

# Create a small Transformer
model = SmallTransformer(
    vocab_size=30000, d_model=256, num_heads=8,
    d_ff=1024, num_layers=4, max_len=512,
    num_classes=5, rngs=nnx.Rngs(42),
)

# Forward pass
token_ids = jnp.ones((2, 128), dtype=jnp.int32)
logits = model(token_ids)
print(f"Output shape: {logits.shape}")  # (2, 5)

# Count trainable parameters only. Filter on nnx.Param so that other state
# (e.g. RNG state handed to nnx.Dropout) is not counted as parameters.
num_params = sum(p.size for p in jax.tree.leaves(nnx.state(model, nnx.Param)))
# ~11.0M total; the token embedding alone is 30000 x 256 = 7.68M.
print(f"Parameters: {num_params:,}")

External links

Exercise

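single Transformer block 작성 (multi-head attention + MLP + residual + LN). (batch, seq, dim) input 에 forward 하고, output shape 이 input 과 같은지 검증. `nnx.jit` 으로 jit 한 뒤 seq=512, dim=768 에서 forward 시간 측정. 첫 호출에는 compile 시간이 포함되므로 warm-up 으로 한 번 먼저 실행함. 또 JAX 는 비동기로 실행되므로 `.block_until_ready()` 로 결과를 기다린 뒤 시간을 잼. 모든 modern LLM 의 unit cell.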

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.