pytree 의 모든 도구를 합쳐 — 작은 NN 을 from-scratch 로 만들고 학습.
import jax
import jax.numpy as jnp
from jax import random
# ============ 초기화 ============
def init_layer(key, in_dim, out_dim, scale=0.01):
'''단일 layer 의 params'''
k1, k2 = random.split(key)
return {
"W": random.normal(k1, (in_dim, out_dim)) * scale,
"b": jnp.zeros(out_dim),
}
def init_mlp(key, sizes):
'''multi-layer perceptron 의 params'''
keys = random.split(key, len(sizes) - 1)
return [
init_layer(k, in_d, out_d)
for k, in_d, out_d in zip(keys, sizes[:-1], sizes[1:])
]
# 4-layer MLP: 784 → 128 → 64 → 10
key = random.PRNGKey(0)
params = init_mlp(key, [784, 128, 64, 10])
# 구조 확인
print(jax.tree.map(lambda x: x.shape, params))
# [{'W': (784, 128), 'b': (128,)},
# {'W': (128, 64), 'b': (64,)},
# {'W': (64, 10), 'b': (10,)}]
# 총 parameter 개수
n = sum(x.size for x in jax.tree.leaves(params))
print(f"{n:,} parameters")
params 가 — list of dict. nn.Module 같은 클래스 없음. 그냥 데이터.
forward pass
def mlp_apply(params, x):
'''단일 example 의 forward'''
for layer in params[:-1]:
x = jax.nn.relu(x @ layer["W"] + layer["b"])
# 마지막 layer 는 activation 없음 (logits)
return x @ params[-1]["W"] + params[-1]["b"]
# 단일 input
x = jnp.zeros(784)
logits = mlp_apply(params, x)
print(logits.shape) # (10,)
# batch — vmap
mlp_batch = jax.vmap(mlp_apply, in_axes=(None, 0))
batch_x = jnp.zeros((32, 784))
batch_logits = mlp_batch(params, batch_x)
print(batch_logits.shape) # (32, 10)
loss + grad
def cross_entropy_loss(params, x, y):
'''y: one-hot (10,)'''
logits = mlp_apply(params, x)
log_probs = jax.nn.log_softmax(logits)
return -jnp.sum(y * log_probs)
def batch_loss(params, X, Y):
'''X: (B, 784), Y: (B, 10)'''
losses = jax.vmap(cross_entropy_loss, in_axes=(None, 0, 0))(params, X, Y)
return losses.mean()
# gradient — params 와 같은 모양의 gradient pytree
grads = jax.grad(batch_loss)(params, batch_x, batch_y)
print(jax.tree.map(lambda x: x.shape, grads))
# 같은 구조: list of {"W": ..., "b": ...}
학습 step
@jax.jit
def train_step(params, X, Y, lr):
loss, grads = jax.value_and_grad(batch_loss)(params, X, Y)
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
return new_params, loss
# 학습 loop
for step in range(1000):
params, loss = train_step(params, batch_x, batch_y, 0.001)
parameter manipulation
# Weight decay — 모든 W 만 (b 제외)
def add_weight_decay(grads, params, wd=1e-4):
def decay(g, p):
# leaf level 에서는 W 와 b 구분 못함, dict 단위로 작업
return g
# 깔끔히 하려면 — layer 별
new_grads = []
for g, p in zip(grads, params):
new_grads.append({
"W": g["W"] + wd * p["W"],
"b": g["b"], # bias 는 decay 안 함
})
return new_grads
# Gradient clipping
def clip_grads(grads, max_norm=1.0):
leaves = jax.tree.leaves(grads)
total_norm = jnp.sqrt(sum(jnp.sum(l**2) for l in leaves))
scale = jnp.minimum(1.0, max_norm / (total_norm + 1e-6))
return jax.tree.map(lambda g: g * scale, grads)
🌳 pytree NN 의 전체 그림
(1) params 는 pytree (list, dict, dataclass — 자유). (2) forward 는 함수 (params, x) → output. (3) loss 는 함수 (params, X, Y) → scalar. (4) gradient 는 jax.grad 로 자동, params 와 같은 모양의 pytree. (5) update 는 jax.tree.map 로 element-wise. (6) jit 으로 compile. 모든 JAX NN 코드의 골격이 이 6 단계.
Track 10 에서 — Flax NNX 나 Equinox 같은 라이브러리들이 위 패턴을 어떻게 깔끔히 wrapping 하는지 봐. 그러나 underlying 은 — 항상 pytree + 함수.