C.W.K.
Stream
Lesson 04 of 05 · published

완전한 학습 루프

~9 min · training, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

지금까지 배운 것 — model, Optax optimizer, learning-rate schedule — 을 합쳐 완전한 미니 trainer 를 만든다. 실제 production trainer 의 기본 골격과 같은 모양이다 (checkpoint 저장처럼 아직 빠진 부분은 아래 '추가할 만한 것' 목록 참고).

from typing import Any

import jax
import jax.numpy as jnp
from jax import random
import optax
from flax import nnx
from flax import struct

# ============ 1. Model ============
class MLP(nnx.Module):
    """Fully connected network: Linear layers with a ReLU between consecutive layers.

    ``dims`` lists the width at every layer boundary, so ``dims=[a, b, c]``
    builds ``Linear(a, b)`` followed by ``Linear(b, c)``. ReLU is applied after
    every layer except the last one, whose output is returned as-is (logits).

    NOTE(review): newer Flax NNX releases may require a list of sub-modules to
    be wrapped in an NNX container type; confirm against the installed version.
    """

    def __init__(self, dims, *, rngs):
        # Pair each width with its successor: (dims[0], dims[1]), (dims[1], dims[2]), ...
        self.layers = [
            nnx.Linear(d_in, d_out, rngs=rngs)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]

    def __call__(self, x):
        output_layer = self.layers[-1]
        h = x
        for hidden_layer in self.layers[:-1]:
            h = jax.nn.relu(hidden_layer(h))
        return output_layer(h)

# ============ 2. TrainState ============
@struct.dataclass
class TrainState:
    """Immutable bundle of everything one training step reads and produces.

    ``flax.struct.dataclass`` registers the class as a JAX pytree, so it can be
    passed into and returned from ``jax.jit``. Create updated copies with
    ``state.replace(...)``.
    """

    params: Any     # model parameters (the nnx.State split off the module)
    opt_state: Any  # optimizer state returned by optimizer.init(params)
    step: int       # step counter; after a jitted step this is a JAX scalar array
    key: Any        # PRNG key, split once per training step

# ============ 3. 초기화 ============
# Root PRNG key; the fixed seed makes the whole run reproducible.
key = random.PRNGKey(42)
# Derive independent keys: `key` (reused later for the dummy data batches),
# one for parameter initialisation, and one carried in TrainState.
key, init_key, train_key = random.split(key, 3)

# model
model = MLP([784, 256, 128, 10], rngs=nnx.Rngs(init_key))
# Separate the static structure (graphdef) from the nnx.Param leaves, so the
# params can flow through jax.jit / jax.grad as a plain pytree. The single
# nnx.Param filter assumes all model state is Param (true for this Linear-only MLP).
graphdef, params = nnx.split(model, nnx.Param)

# optimizer
# Linear warmup 0 -> 1e-3 over 500 steps, then cosine decay to 1e-5; in Optax,
# decay_steps is the schedule's total length (warmup included).
# NOTE(review): the training loop in this script runs 10 epochs x 50 steps =
# 500 steps, so training ends exactly when warmup finishes and the cosine phase
# is never reached. Consider matching warmup_steps/decay_steps to the real step count.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=500,
    decay_steps=5000,
    end_value=1e-5,
)
# Clip the global gradient norm to 1.0, then apply AdamW (weight decay 1e-4)
# driven by the schedule above.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=1e-4),
)
opt_state = optimizer.init(params)

# Initial training state; step starts at 0 as a Python int.
state = TrainState(
    params=params,
    opt_state=opt_state,
    step=0,
    key=train_key,
)

# ============ 4. loss 와 train step ============
def loss_fn(params, x, y, key):
    """Return the batch cross-entropy loss plus accuracy as an auxiliary metric.

    Rebuilds the module from the global ``graphdef`` and the given ``params``,
    so the function can be differentiated with respect to ``params`` via
    ``jax.value_and_grad(..., has_aux=True)``.

    Args:
        params: parameter state to evaluate.
        x: batch of inputs.
        y: target class distributions (one-hot in this script).
        key: PRNG key; accepted for interface compatibility but unused here.

    Returns:
        ``(loss, {"acc": accuracy})``: the mean cross-entropy and the batch accuracy.
    """
    net = nnx.merge(graphdef, params)
    logits = net(x)
    # Per-example cross-entropy against the target distribution is -sum(y * log p).
    per_example = jnp.sum(y * jax.nn.log_softmax(logits), axis=-1)
    loss = -jnp.mean(per_example)
    hits = jnp.argmax(logits, axis=-1) == jnp.argmax(y, axis=-1)
    return loss, {"acc": jnp.mean(hits)}

@jax.jit
def train_step(state, batch):
    """Apply one optimizer update and return ``(new_state, loss, metrics)``.

    ``batch`` is an ``(inputs, targets)`` pair. The state's PRNG key is split:
    one half is handed to the loss for this step, and the other becomes the
    key stored in the returned state.
    """
    inputs, targets = batch
    next_key, step_key = random.split(state.key)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, metrics), grads = grad_fn(state.params, inputs, targets, step_key)

    # The current params are passed so AdamW can apply its weight decay.
    updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)

    next_state = state.replace(
        params=optax.apply_updates(state.params, updates),
        opt_state=new_opt_state,
        step=state.step + 1,
        key=next_key,
    )
    return next_state, loss, metrics

# ============ 5. 평가 step ============
@jax.jit
def eval_step(state, batch):
    """Return the accuracy of ``state.params`` on one ``(inputs, one-hot targets)`` batch."""
    inputs, targets = batch
    logits = nnx.merge(graphdef, state.params)(inputs)
    correct = jnp.argmax(logits, axis=-1) == jnp.argmax(targets, axis=-1)
    return jnp.mean(correct)

# ============ 6. 학습 루프 ============
import time

def make_batch(key, batch_size=128, num_features=784, num_classes=10):
    """Return a dummy ``(x, y)`` batch; a stand-in for a real dataset.

    The same ``key`` always yields the same batch.

    Args:
        key: JAX PRNG key.
        batch_size: number of examples in the batch.
        num_features: width of each input row.
        num_classes: number of classes for the random one-hot labels.

    Returns:
        x: standard-normal inputs of shape ``(batch_size, num_features)``.
        y: one-hot labels of shape ``(batch_size, num_classes)``.
    """
    x_key, y_key = random.split(key)
    x = random.normal(x_key, (batch_size, num_features))
    labels = random.randint(y_key, (batch_size,), 0, num_classes)
    y = jax.nn.one_hot(labels, num_classes)
    return x, y

# Dummy-data run: num_epochs * steps_per_epoch optimizer steps in total.
num_epochs = 10
steps_per_epoch = 50

# The validation batch uses a fixed key, so it is identical every epoch:
# build it once instead of regenerating it inside the loop.
val_batch = make_batch(random.fold_in(key, 999_999), batch_size=512)

# perf_counter is the monotonic, high-resolution clock intended for elapsed time.
t = time.perf_counter()
for epoch in range(num_epochs):
    # Training. Per-step metrics are summed as JAX arrays, so nothing is pulled
    # to the host until the epoch print. Epoch averages are reported instead of
    # only the last batch's values.
    epoch_loss = 0.0
    epoch_acc = 0.0
    for batch_idx in range(steps_per_epoch):
        # Global step index gives a distinct data key per batch (it stays clear
        # of the validation key 999_999 for any run shorter than 10**6 steps).
        batch_key = random.fold_in(key, epoch * steps_per_epoch + batch_idx)
        state, loss, metrics = train_step(state, make_batch(batch_key))
        epoch_loss += loss
        epoch_acc += metrics["acc"]
    avg_loss = epoch_loss / steps_per_epoch
    avg_acc = epoch_acc / steps_per_epoch

    # Epoch evaluation on the fixed validation batch.
    val_acc = eval_step(state, val_batch)

    current_lr = schedule(state.step)
    print(f"epoch {epoch:2d}  step {state.step:4d}  "
          f"loss={avg_loss:.4f}  acc={avg_acc:.3f}  "
          f"val_acc={val_acc:.3f}  lr={current_lr:.6f}")

print(f"\n총 학습 시간: {time.perf_counter()-t:.1f}s")

이게 — production-ready 의 골격. 추가할 만한 것:

  • 실제 data loader (PyTorch DataLoader 호환 또는 grain / tf.data)
  • wandb / tensorboard logging
  • checkpoint 저장 (Track 11-5)
  • multi-GPU 분산 (Track 7)
  • mixed precision (계산 dtype 은 Flax layer 의 dtype / param_dtype 인자로, loss scaling 은 jmp 라이브러리 + optax.apply_if_finite 로)
  • gradient accumulation (effective batch size 키우기)

📐 production 코드의 7 가지 구성

(1) model definition — pytree 또는 nnx/eqx Module. (2) train_state — params + opt_state + step + key. (3) loss + metrics — value_and_grad with has_aux. (4) train_step — jit 으로 compile. (5) eval_step — model.eval() 같은 mode 분리. (6) schedule — warmup + cosine. (7) optimizer — optax.chain. 어느 model 이든 — 이 7 개 구성이 같은 모양.

이 패턴 — Llama-3 70B 학습 코드도, 작은 MLP 도, 기본 모양은 동일. 차이는 — model 의 크기, 데이터 양, multi-host 분산 코드 추가.

Code

from flax import nnx
import jax
import jax.numpy as jnp
import optax

# --- Model Definition ---
class Classifier(nnx.Module):
    """Small CNN: two conv + max-pool stages, then a two-layer MLP head with dropout.

    Takes channels-last images with a single input channel and returns logits
    over 10 classes. NOTE(review): the input size of ``linear1`` assumes 28x28
    images (MNIST-like); TODO confirm against the data loader.
    """

    def __init__(self, rngs: nnx.Rngs):
        # 'VALID' padding shrinks the spatial size 28 -> 26 -> 13 (pool) -> 11 -> 5 (pool),
        # which is what linear1's 64 * 5 * 5 input size requires. nnx.Conv defaults
        # to 'SAME' padding, which would leave 7x7 maps and cause a shape mismatch.
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), padding='VALID', rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), padding='VALID', rngs=rngs)
        self.linear1 = nnx.Linear(64 * 5 * 5, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.conv1(x))
        x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nnx.relu(self.conv2(x))
        x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten the feature maps to (batch, features) for the dense head.
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.linear1(x))
        # Dropout is only active while the module is in training mode.
        x = self.dropout(x)
        return self.linear2(x)

# --- Setup ---
# Seed 0 drives parameter initialisation and the Dropout layer's RNG stream.
model = Classifier(rngs=nnx.Rngs(0))

# Linear warmup 0 -> 1e-3 over 500 steps, then cosine decay; decay_steps is the
# schedule's total length (warmup included), and end_value keeps Optax's default.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=500,
    decay_steps=10000,
)

# Clip the global gradient norm to 1.0 before every AdamW update.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule),
)
# NOTE(review): newer Flax releases changed the nnx.Optimizer API (an explicit
# `wrt=` filter and `optimizer.update(model, grads)`); confirm this call and the
# `optimizer.update(grads)` call in train_step against the installed Flax version.
optimizer = nnx.Optimizer(model, tx)

# --- Training Step ---
@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimisation step on ``(x, y)`` and return the mean batch loss.

    ``y`` holds integer class labels. ``nnx.jit`` propagates the in-place
    changes made to ``model`` and ``optimizer`` back out of the compiled call.
    """
    def compute_loss(m):
        per_example = optax.softmax_cross_entropy_with_integer_labels(m(x), y)
        return jnp.mean(per_example)

    grad_fn = nnx.value_and_grad(compute_loss)
    loss, grads = grad_fn(model)
    optimizer.update(grads)
    return loss

# --- Evaluation Step ---
@nnx.jit
def eval_step(model, x, y):
    """Return the fraction of examples whose argmax prediction equals the integer label ``y``.

    Stochastic layers such as dropout are only disabled if the caller has put
    ``model`` into evaluation mode beforehand.
    """
    correct = jnp.argmax(model(x), axis=-1) == y
    return jnp.mean(correct)

# --- Training Loop ---
# NOTE(review): train_loader / val_loader are not defined in this listing; they
# are assumed to yield (inputs, integer labels) array batches.
num_epochs = 10
for epoch in range(num_epochs):
    # Training: training mode keeps dropout active.
    model.train()
    train_loss = 0.0
    num_batches = 0
    for x_batch, y_batch in train_loader:
        loss = train_step(model, optimizer, x_batch, y_batch)
        train_loss += loss
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = train_loss / num_batches

    # Evaluation: evaluation mode disables dropout, so val accuracy is not
    # computed with randomly dropped activations.
    model.eval()
    total_correct = 0.0
    total_seen = 0
    for x_val, y_val in val_loader:
        acc = eval_step(model, x_val, y_val)
        # Weight each batch by its size so a short final batch does not skew the mean.
        val_count = y_val.shape[0]
        total_correct += acc * val_count
        total_seen += val_count
    if total_seen == 0:
        raise ValueError("val_loader yielded no examples")

    avg_acc = total_correct / total_seen
    print(f"Epoch {epoch}: loss={avg_loss:.4f}, val_acc={avg_acc:.4f}")

External links

Exercise

결합: model (Flax NNX 또는 Equinox), loss, optax optimizer, jit'd train_step, 단순 data loader, 100-step loop. 10 step 마다 학습 loss 출력. 마지막 params 저장. minimum-viable JAX trainer template.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요? 문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인 · 댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.