C.W.K.
Stream
Lesson 05 of 05 · published

실전: Pytree 로 Neural Network 표현

~12 min · pytrees, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

pytree 의 모든 도구를 합쳐 — 작은 NN 을 from-scratch 로 만들고 학습.

import jax
import jax.numpy as jnp
from jax import random

# ============ Initialization ============
def init_layer(key, in_dim, out_dim, scale=0.01):
    '''Create the parameters of one dense layer.

    Args:
        key: PRNG key. It is split in two and the first subkey draws the weights.
        in_dim: number of input features.
        out_dim: number of output features.
        scale: multiplier applied to the standard-normal weight draw.

    Returns:
        dict with "W" of shape (in_dim, out_dim) and an all-zero "b" of shape (out_dim,).
    '''
    weight_key, _ = random.split(key)
    weights = random.normal(weight_key, (in_dim, out_dim)) * scale
    bias = jnp.zeros(out_dim)
    return {"W": weights, "b": bias}

def init_mlp(key, sizes):
    '''Build multi-layer perceptron params as a list of per-layer dicts.

    One layer is created for each consecutive (in_dim, out_dim) pair in
    `sizes`, and each layer gets its own subkey split from `key`.
    '''
    layer_keys = random.split(key, len(sizes) - 1)
    return list(map(init_layer, layer_keys, sizes[:-1], sizes[1:]))

# MLP with three weight layers (four layer sizes): 784 → 128 → 64 → 10
key = random.PRNGKey(0)
params = init_mlp(key, [784, 128, 64, 10])

# Inspect the structure: map every leaf array to its shape
print(jax.tree.map(lambda x: x.shape, params))
# [{'W': (784, 128), 'b': (128,)},
#  {'W': (128, 64), 'b': (64,)},
#  {'W': (64, 10), 'b': (10,)}]

# Total number of parameters: sum of the sizes of all leaves
n = sum(x.size for x in jax.tree.leaves(params))
print(f"{n:,} parameters")   # 109,386 parameters

params 가 — list of dict. nn.Module 같은 클래스 없음. 그냥 데이터.

forward pass

def mlp_apply(params, x):
    '''Forward pass for a single example.

    Each layer except the last computes relu(x @ W + b). The last layer is
    affine only, so the return value is raw logits.
    '''
    activations = x
    for hidden in params[:-1]:
        activations = jax.nn.relu(activations @ hidden["W"] + hidden["b"])
    head = params[-1]
    return activations @ head["W"] + head["b"]

# Single input: one all-zeros example with 784 features
x = jnp.zeros(784)
logits = mlp_apply(params, x)
print(logits.shape)   # (10,)

# Batch via vmap: in_axes=(None, 0) shares params across the batch (not
# mapped) and maps over axis 0 of the input
mlp_batch = jax.vmap(mlp_apply, in_axes=(None, 0))
batch_x = jnp.zeros((32, 784))
batch_logits = mlp_batch(params, batch_x)
print(batch_logits.shape)   # (32, 10)

loss + grad

def cross_entropy_loss(params, x, y):
    '''Cross-entropy loss of a single example.

    y is a one-hot target vector of shape (10,). Returns
    -sum(y * log_softmax(logits)).
    '''
    log_probs = jax.nn.log_softmax(mlp_apply(params, x))
    weighted = y * log_probs
    return -jnp.sum(weighted)

def batch_loss(params, X, Y):
    '''Mean per-example cross-entropy over a batch.

    X has shape (B, 784) and Y holds one-hot targets of shape (B, 10).
    params is shared across the batch (in_axes None), while X and Y are
    mapped along axis 0.
    '''
    per_example_loss = jax.vmap(cross_entropy_loss, in_axes=(None, 0, 0))
    return jnp.mean(per_example_loss(params, X, Y))

# One-hot labels for batch_x: classes 0..9 repeated, shape (batch, 10).
# Defined here because batch_loss (and the training loop below) need Y.
batch_y = jax.nn.one_hot(jnp.arange(batch_x.shape[0]) % 10, 10)

# Gradient: a pytree with exactly the same structure as params
grads = jax.grad(batch_loss)(params, batch_x, batch_y)
print(jax.tree.map(lambda x: x.shape, grads))
# Same structure: list of {"W": ..., "b": ...}

학습 step

@jax.jit
def train_step(params, X, Y, lr):
    '''One jitted SGD step.

    Returns (new_params, loss), where loss is batch_loss evaluated at the
    incoming params, before the update is applied.
    '''
    loss_and_grads = jax.value_and_grad(batch_loss)
    loss, grads = loss_and_grads(params, X, Y)

    def sgd(param, grad):
        # Plain SGD applied leaf-by-leaf over the params pytree.
        return param - lr * grad

    return jax.tree.map(sgd, params, grads), loss

# Training loop: 1000 plain-SGD steps on the same (batch_x, batch_y) batch
# with learning rate 0.001. `loss` holds the batch loss evaluated before the
# final update.
# NOTE(review): requires `batch_y` to be defined before this point.
for step in range(1000):
    params, loss = train_step(params, batch_x, batch_y, 0.001)

parameter manipulation

# Weight decay: applied to every W only (biases b are excluded)
def add_weight_decay(grads, params, wd=1e-4):
    '''Add L2 weight decay (wd * W) to the gradients of weight leaves only.

    grads and params are walked together with their key paths. A leaf stored
    under the dict key "W" gets `g + wd * p`; every other leaf (e.g. "b") is
    returned unchanged. This works for the list-of-dicts layout used here and
    for any other pytree layout that names its weight matrices "W".

    Args:
        grads: gradient pytree with the same structure as params.
        params: parameter pytree. JAX raises an error if its structure
            differs from grads.
        wd: weight-decay coefficient.

    Returns:
        A new pytree with the same structure as grads.
    '''
    def decay(path, g, p):
        # path is a tuple of key entries; its last entry names this leaf.
        last = path[-1] if path else None
        is_weight = isinstance(last, jax.tree_util.DictKey) and last.key == "W"
        return g + wd * p if is_weight else g

    return jax.tree_util.tree_map_with_path(decay, grads, params)

# Gradient clipping by global norm
def clip_grads(grads, max_norm=1.0):
    '''Rescale a gradient pytree so its global L2 norm is at most ~max_norm.

    The global norm is sqrt(sum of squares over every leaf). Every leaf is
    multiplied by min(1, max_norm / (norm + 1e-6)). The 1e-6 avoids division
    by zero when all gradients are zero.
    '''
    squared_sums = [jnp.sum(leaf ** 2) for leaf in jax.tree.leaves(grads)]
    global_norm = jnp.sqrt(sum(squared_sums))
    factor = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    return jax.tree.map(lambda leaf: leaf * factor, grads)

🌳 pytree NN 의 전체 그림

(1) params 는 pytree (list, dict, tuple 등 — 자유. 단, dataclass 는 jax.tree_util.register_dataclass 등으로 pytree 노드로 등록해야 하며, 등록하지 않으면 통째로 하나의 leaf 로 취급됨). (2) forward 는 함수 (params, x) → output. (3) loss 는 함수 (params, X, Y) → scalar. (4) gradient 는 jax.grad 로 자동, params 와 같은 모양의 pytree. (5) update 는 jax.tree.map 로 element-wise. (6) jit 으로 compile. 모든 JAX NN 코드의 골격이 이 6 단계.

Track 10 에서 — Flax NNX 나 Equinox 같은 라이브러리들이 위 패턴을 어떻게 깔끔히 wrapping 하는지 봐. 그러나 underlying 은 — 항상 pytree + 함수.

Code

import jax
import jax.numpy as jnp

def init_mlp(key, layer_sizes):
    """Initialize an MLP as a pytree: a list of {'weights', 'bias'} dicts.

    For each consecutive (fan_in, fan_out) pair the carried key is split into
    three. The first subkey becomes the new carry, the second draws the
    weights, and the third is unused. Weights use Glorot/Xavier normal
    initialization, std = sqrt(2 / (fan_in + fan_out)). Biases start at zero.
    """
    layers = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, weight_key, _ = jax.random.split(key, 3)
        glorot_std = jnp.sqrt(2.0 / (fan_in + fan_out))
        weights = jax.random.normal(weight_key, (fan_in, fan_out)) * glorot_std
        layers.append({'weights': weights, 'bias': jnp.zeros(fan_out)})
    return layers

def mlp_forward(params, x):
    """Forward pass through the MLP pytree.

    Hidden layers compute relu(x @ weights + bias). The last layer is affine
    only and returns logits. Works on a single vector or on a batch with a
    leading axis, because `@` contracts the last axis of x.
    """
    h = x
    for hidden in params[:-1]:
        h = jax.nn.relu(h @ hidden['weights'] + hidden['bias'])
    head = params[-1]
    return h @ head['weights'] + head['bias']

# Initialize a 784 → 256 → 128 → 10 MLP (three weight layers) from a fixed seed
key = jax.random.key(42)
params = init_mlp(key, [784, 256, 128, 10])

# Check structure: print each layer's parameter shapes. Expected output:
for i, layer in enumerate(params):
    print(f"Layer {i}: weights {layer['weights'].shape}, "
          f"bias {layer['bias'].shape}")
# Layer 0: weights (784, 256), bias (256,)
# Layer 1: weights (256, 128), bias (128,)
# Layer 2: weights (128, 10), bias (10,)
def loss_fn(params, x, y):
    """Cross-entropy loss for classification.

    y holds one-hot targets aligned with the logits. The per-example loss is
    summed over the last (class) axis, and the result is averaged over the
    remaining axes.
    """
    log_probs = jax.nn.log_softmax(mlp_forward(params, x))
    per_example = jnp.sum(log_probs * y, axis=-1)
    return -jnp.mean(per_example)

@jax.jit
def train_step(params, x, y, lr=0.001):
    """Single jitted SGD step; params, grads and outputs are all pytrees.

    Returns (new_params, loss), where loss is loss_fn evaluated at the
    incoming params, before the update is applied.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, x, y)

    def sgd_update(param, grad):
        # SGD: subtract the scaled gradient from each parameter leaf.
        return param - lr * grad

    return jax.tree.map(sgd_update, params, grads), loss

# Simulate training on a synthetic batch: 32 standard-normal inputs and
# cyclic one-hot labels (classes 0..9, shape (32, 10))
x_batch = jax.random.normal(jax.random.key(0), (32, 784))
y_batch = jax.nn.one_hot(jnp.arange(32) % 10, 10)

# 100 SGD steps with train_step's default lr. Logs every 20 steps the loss
# evaluated before that step's update.
for step in range(100):
    params, loss = train_step(params, x_batch, y_batch)
    if step % 20 == 0:
        print(f"Step {step}: loss = {loss:.4f}")
# grad returns a pytree matching the input
grads = jax.grad(loss_fn)(params, x_batch, y_batch)
# grads[0]['weights'].shape == params[0]['weights'].shape ✓

# jit works on functions that take/return pytrees
fast_forward = jax.jit(mlp_forward)
logits = fast_forward(params, x_batch)

# vmap over the batch dimension while keeping params fixed:
# in_axes=(None, 0) passes params unmapped and maps over axis 0 of x
batch_logits = jax.vmap(mlp_forward, in_axes=(None, 0))(params, x_batch)

# Compute parameter norms (for the 2-D weight matrices this is the Frobenius norm)
param_norms = jax.tree.map(jnp.linalg.norm, params)
# param_norms has same structure: list of {'weights': scalar, 'bias': scalar}

# Total L2 regularization term: sum of squares over every leaf
l2_reg = sum(jnp.sum(x ** 2) for x in jax.tree.leaves(params))
print(f"L2 norm: {jnp.sqrt(l2_reg):.2f}")
# PyTorch: parameters live inside the Module object
# class MLP(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.layer1 = nn.Linear(784, 256)
#     def forward(self, x):
#         return self.layer1(x)
# model = MLP()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# JAX: parameters are a plain data structure, functions are separate
# params = init_mlp(key, [784, 256, 10])
# grads = jax.grad(loss_fn)(params, x, y)
# params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

# Both work — JAX separates data from logic, PyTorch bundles them together

External links

Exercise

params 가 single nested dict (pytree) 인 2-layer MLP 를 직접 만들어 보자: 초기화, forward, MSE 계산, grad. 마지막으로 grad 결과 dict 의 구조가 초기 params dict 와 leaf-by-leaf 로 일치하는지 확인. 이 quest 의 mental model 이 여기서 완성된다.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요? 문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인 · 댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.