현대 ML 의 unit cell — Transformer block. multi-head attention + MLP + residual + layer norm. JAX 로 만들기.
import jax
import jax.numpy as jnp
from flax import nnx
class TransformerBlock(nnx.Module):
    """Pre-norm Transformer block: multi-head self-attention followed by an MLP.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back to the input (residual connection).

    Args:
        dim: Feature size of the input and output.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        rngs: ``nnx.Rngs`` used to initialise the parameters.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, *, rngs):
        hidden_dim = mlp_ratio * dim
        self.ln1 = nnx.LayerNorm(dim, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=dim,
            # Without an explicit `decode` flag the NNX attention layer
            # refuses to run; this block never uses the autoregressive cache.
            decode=False,
            rngs=rngs,
        )
        self.ln2 = nnx.LayerNorm(dim, rngs=rngs)
        # nnx.Sequential takes its layers as positional arguments, not a list.
        self.mlp = nnx.Sequential(
            nnx.Linear(dim, hidden_dim, rngs=rngs),
            jax.nn.gelu,
            nnx.Linear(hidden_dim, dim, rngs=rngs),
        )

    def __call__(self, x, mask=None):
        """Apply the block.

        Args:
            x: Input activations, expected as (B, L, D).
            mask: Optional boolean attention mask forwarded to the attention
                layer; ``None`` means unrestricted attention.

        Returns:
            Array with the same shape as ``x``.
        """
        # Pre-norm style (the standard in GPT, Llama, ...).
        # Self-attention: passing only the query input makes the layer reuse
        # it for keys and values.
        h = self.attn(self.ln1(x), mask=mask)
        x = x + h  # residual
        h = self.mlp(self.ln2(x))
        x = x + h  # residual
        return x
# Usage: build one block and run an all-zero dummy batch through it to check
# that the output shape matches the input shape.
block = TransformerBlock(dim=512, num_heads=8, rngs=nnx.Rngs(0))
x = jnp.zeros((4, 64, 512)) # batch=4, seq=64, dim=512
y = block(x)
print(y.shape) # (4, 64, 512)
self-attention 의 분해
위 코드는 NNX 의 built-in attention (nnx.MultiHeadAttention) 을 사용한다. 같은 연산을 직접 구현하면:
class SelfAttention(nnx.Module):
    """Multi-head scaled dot-product self-attention written out by hand.

    Args:
        dim: Feature size of the input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        rngs: ``nnx.Rngs`` used to initialise the parameters.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, *, rngs):
        if dim % num_heads != 0:
            # Fail early: otherwise the reshape in __call__ breaks with an
            # unrelated-looking shape error.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nnx.Linear(dim, 3 * dim, rngs=rngs)
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x, mask=None):
        """Run self-attention over ``x``.

        Args:
            x: Array of shape (B, L, D).
            mask: Optional boolean array broadcastable to (B, H, L, L);
                positions where it is False are excluded from attention.

        Returns:
            Array of shape (B, L, D).
        """
        B, L, D = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)  # (3, B, H, L, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention scores: Q K^T / sqrt(head_dim), per batch and head.
        scale = self.head_dim ** -0.5
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        if mask is not None:
            # Use the most negative finite value of the score dtype so the
            # fill stays representable in half precision as well.
            fill = jnp.finfo(scores.dtype).min
            scores = jnp.where(mask, scores, fill)
        attn = jax.nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge the heads back into D.
        out = jnp.einsum("bhqk,bhkd->bhqd", attn, v)
        out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
        return self.proj(out)
einsum 으로 — attention 의 수식이 코드에 직접 들어옴. jnp.einsum("bhqd,bhkd->bhqk", q, k) = $Q K^T$ 의 batch / head version.
causal mask
def causal_mask(seq_len):
    """Return an L x L boolean lower-triangular mask for autoregressive use.

    Entry (i, j) is True when j <= i, i.e. position i may attend to itself
    and to earlier positions only.
    """
    lower_triangle = jnp.tri(seq_len, dtype=bool)
    return lower_triangle
# 64 matches the sequence length of the dummy input `x` built earlier.
mask = causal_mask(64)
y = block(x, mask=mask) # attention with the mask applied
positional encoding
# GPT-style language model with a learnable positional embedding.
class GPTModel(nnx.Module):
    """Decoder-only Transformer language model.

    Args:
        vocab: Vocabulary size (number of token ids and of output logits).
        dim: Model width.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_heads: Number of attention heads per block.
        max_len: Longest sequence the positional embedding table supports.
        rngs: ``nnx.Rngs`` used to initialise the parameters.
    """

    def __init__(self, vocab, dim, num_layers, num_heads, max_len, *, rngs):
        self.tok_emb = nnx.Embed(vocab, dim, rngs=rngs)
        # Learnable positional table, one row per position, small init.
        self.pos_emb = nnx.Param(
            jax.random.normal(rngs.params(), (max_len, dim)) * 0.01
        )
        blocks = [
            TransformerBlock(dim, num_heads, rngs=rngs)
            for _ in range(num_layers)
        ]
        # NOTE(review): newer Flax releases reject plain Python lists of
        # sub-modules and expect nnx.List; older releases have no nnx.List
        # and accept the plain list. Support both.
        self.blocks = nnx.List(blocks) if hasattr(nnx, "List") else blocks
        self.ln_f = nnx.LayerNorm(dim, rngs=rngs)
        self.head = nnx.Linear(dim, vocab, rngs=rngs)

    def __call__(self, tokens):
        """Compute next-token logits.

        Args:
            tokens: Integer token ids of shape (B, L), with L <= max_len.

        Returns:
            Logits of shape (B, L, vocab).

        Raises:
            ValueError: If the sequence is longer than the positional table.
        """
        seq_len = tokens.shape[1]
        max_len = self.pos_emb.value.shape[0]
        if seq_len > max_len:
            # Slicing past the table would silently return fewer rows and
            # then fail later with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        x = self.tok_emb(tokens) + self.pos_emb.value[:seq_len]
        # Every block shares the same causal mask.
        mask = causal_mask(seq_len)
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        return self.head(x)  # (B, L, vocab)
# Build a small model and run an all-zero dummy token batch through it.
model = GPTModel(vocab=10000, dim=512, num_layers=6, num_heads=8, max_len=128, rngs=nnx.Rngs(0))
tokens = jnp.zeros((2, 64), dtype=jnp.int32)
logits = model(tokens)
print(logits.shape) # (2, 64, 10000)
# Parameter count: sum the sizes of every nnx.Param leaf in the model state.
n_params = sum(x.size for x in jax.tree.leaves(nnx.state(model, nnx.Param)))
# Roughly 29M for this configuration, assuming the layers keep their default
# bias terms: 6 blocks ~18.9M, token embedding ~5.1M, output head ~5.1M.
print(f"{n_params:,} parameters") # ~ 29M
🌟 LLM 의 unit cell
지금 만든 GPTModel — 1 줄 줄이고 1 단어 늘리면 진짜 GPT 스타일 LM. dim 키우고 layers 키우면 — Llama 같은 거. 토큰화 + 학습 데이터 + 적절한 학습 설정만 추가하면 — 작은 LM 학습 가능. JAX 의 표현력으로는 — 학생이 1 주에 만들 수 있는 코드.
이 quest 의 마지막 GPT-flavored model 이 — 지금 GPT-4, Claude, Gemini 같은 모델의 핵심 구조. 차이는 — 데이터 양, computing power, 미세 조정 trick. 코드 자체는 — 같은 패턴.