지금까지 배운 것 — model, Optax optimizer, learning-rate schedule, TrainState — 을 합쳐 완전한 미니 trainer 를 만든다. 실제 production 학습 코드의 기본 골격을 그대로 갖춘 모양이다. (checkpoint 저장은 이 코드에 포함되어 있지 않으며, 아래 "추가할 만한 것" 에서 다룬다.)
from typing import Any

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax import struct
from jax import random
# ============ 1. Model ============
class MLP(nnx.Module):
    """Stack of linear layers with ReLU applied after every layer but the last.

    ``dims`` lists the layer widths from input to output; one ``nnx.Linear``
    is created for each consecutive pair of widths.
    """

    def __init__(self, dims, *, rngs):
        """Create the linear layers.

        Args:
            dims: sequence of layer widths, e.g. ``[784, 256, 128, 10]``.
            rngs: ``nnx.Rngs`` forwarded to every ``nnx.Linear``.
        """
        width_pairs = zip(dims[:-1], dims[1:])
        self.layers = [
            nnx.Linear(fan_in, fan_out, rngs=rngs)
            for fan_in, fan_out in width_pairs
        ]

    def __call__(self, x):
        """Run ``x`` through the stack and return the last layer's raw output."""
        hidden = x
        for linear in self.layers[:-1]:
            hidden = jax.nn.relu(linear(hidden))
        # The final layer is applied without an activation.
        return self.layers[-1](hidden)
# ============ 2. TrainState ============
@struct.dataclass
class TrainState:
    """Bundle of the values that ``train_step`` reads and replaces each step.

    New states are produced with ``state.replace(...)`` rather than by
    mutating fields in place.
    """

    # Annotated with ``typing.Any``: the original used the builtin function
    # ``any``, which is not a type.

    # Parameter state handed to ``nnx.merge`` and to the optimizer.
    params: Any
    # Optimizer state returned by ``optimizer.init`` / ``optimizer.update``.
    opt_state: Any
    # Number of optimizer updates applied so far; incremented by one per step.
    step: int
    # PRNG key carried across steps; split anew on every ``train_step`` call.
    key: Any
# ============ 3. Initialization ============
# One root key, split once: `init_key` seeds the model weights, `train_key`
# is carried inside TrainState, and the remaining `key` is used later to
# derive per-batch data keys.
key = random.PRNGKey(42)
key, init_key, train_key = random.split(key, 3)

# Model: separate the module into a graph definition and its parameters so
# the parameters can be passed through jit/grad and re-merged when needed.
model = MLP([784, 256, 128, 10], rngs=nnx.Rngs(init_key))
graphdef, params = nnx.split(model, nnx.Param)

# Optimizer.
# The schedule horizon has to match the number of optimizer steps actually
# run. The training loop performs 10 epochs x 50 batches = 500 steps; with the
# previous warmup_steps=500 / decay_steps=5000 the whole run stayed inside the
# warmup ramp and the cosine decay was never reached.
TOTAL_TRAIN_STEPS = 10 * 50
WARMUP_STEPS = TOTAL_TRAIN_STEPS // 10

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=WARMUP_STEPS,
    # `decay_steps` is the total schedule length, warmup included.
    decay_steps=TOTAL_TRAIN_STEPS,
    end_value=1e-5,
)

# Gradients are clipped by global norm first, then handed to AdamW.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=1e-4),
)
opt_state = optimizer.init(params)

state = TrainState(
    params=params,
    opt_state=opt_state,
    step=0,
    key=train_key,
)
# ============ 4. Loss and train step ============
def loss_fn(params, x, y, key):
    """Compute the mean cross-entropy loss plus an accuracy metric.

    Args:
        params: parameter state merged with the module-level ``graphdef``
            to rebuild the model.
        x: model inputs.
        y: targets with one score per class on the last axis (presumably
            one-hot; they are multiplied with the log-probabilities and
            arg-maxed over the last axis).
        key: PRNG key. Accepted for interface compatibility; this function
            does not use it.

    Returns:
        ``(loss, aux)`` where ``aux`` is ``{"acc": accuracy}``.
    """
    net = nnx.merge(graphdef, params)
    logits = net(x)

    # Cross-entropy: negative mean over examples of sum(y * log p).
    weighted_log_probs = y * jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.sum(weighted_log_probs, axis=-1))

    predicted = jnp.argmax(logits, axis=-1)
    expected = jnp.argmax(y, axis=-1)
    aux = {"acc": jnp.mean(predicted == expected)}
    return loss, aux
@jax.jit
def train_step(state, batch):
    """Apply one optimizer update and return the advanced state.

    Args:
        state: current ``TrainState``.
        batch: ``(x, y)`` pair passed on to ``loss_fn``.

    Returns:
        ``(new_state, loss, metrics)`` where ``metrics`` is the auxiliary
        dict returned by ``loss_fn``.
    """
    inputs, targets = batch
    # Keep one half of the split in the state, spend the other on this step.
    carried_key, step_key = random.split(state.key)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, metrics), grads = grad_fn(state.params, inputs, targets, step_key)

    # The current params are passed to `update` as well (third argument).
    updates, next_opt_state = optimizer.update(
        grads, state.opt_state, state.params
    )
    next_params = optax.apply_updates(state.params, updates)

    next_state = state.replace(
        params=next_params,
        opt_state=next_opt_state,
        step=state.step + 1,
        key=carried_key,
    )
    return next_state, loss, metrics
# ============ 5. Evaluation step ============
@jax.jit
def eval_step(state, batch):
    """Return the classification accuracy of ``state.params`` on ``batch``.

    Args:
        state: ``TrainState`` whose ``params`` are evaluated.
        batch: ``(x, y)`` pair; ``y`` is arg-maxed over its last axis to get
            the target class.

    Returns:
        Mean of the per-example "prediction equals target" indicator.
    """
    inputs, targets = batch
    logits = nnx.merge(graphdef, state.params)(inputs)
    is_correct = jnp.argmax(logits, axis=-1) == jnp.argmax(targets, axis=-1)
    return jnp.mean(is_correct)
# ============ 6. Training loop ============
import time
def make_batch(key, batch_size=128, *, num_features=784, num_classes=10):
    """Generate a random dummy batch (a stand-in for a real dataset).

    Args:
        key: JAX PRNG key; split once into a feature key and a label key.
        batch_size: number of examples in the batch.
        num_features: width of each input vector. Defaults to 784, the value
            that was previously hard-coded.
        num_classes: number of label classes. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        ``(x, y)``: ``x`` holds standard-normal samples of shape
        ``(batch_size, num_features)``; ``y`` holds one-hot labels of shape
        ``(batch_size, num_classes)`` drawn uniformly from
        ``[0, num_classes)``.
    """
    feature_key, label_key = random.split(key)
    x = random.normal(feature_key, (batch_size, num_features))
    labels = random.randint(label_key, (batch_size,), 0, num_classes)
    y = jax.nn.one_hot(labels, num_classes)
    return x, y
# The validation batch is derived from a fixed key, so it is identical every
# epoch; build it once instead of regenerating it inside the loop.
val_batch = make_batch(random.fold_in(key, 999_999), batch_size=512)

# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments. The measurement includes the JIT compilation that
# happens on the first train_step / eval_step calls.
t = time.perf_counter()
for epoch in range(10):
    # Training: each batch gets its own key derived from (epoch, batch_idx).
    for batch_idx in range(50):
        batch = make_batch(random.fold_in(key, epoch * 1000 + batch_idx))
        state, loss, metrics = train_step(state, batch)

    # Per-epoch evaluation. `loss` and `metrics` below are those of the last
    # training batch of the epoch, not an epoch average.
    val_acc = eval_step(state, val_batch)
    current_lr = schedule(state.step)
    print(f"epoch {epoch:2d} step {state.step:4d} "
          f"loss={loss:.4f} acc={metrics['acc']:.3f} "
          f"val_acc={val_acc:.3f} lr={current_lr:.6f}")
print(f"\n총 학습 시간: {time.perf_counter()-t:.1f}s")
이게 — production-ready 의 골격. 추가할 만한 것:
- 실제 data loader (PyTorch DataLoader 호환 또는 grain / tf.data)
- wandb / tensorboard logging
- checkpoint 저장 (Track 11-5)
- multi-GPU 분산 (Track 7)
- mixed precision (Optax 가 helper 제공)
- gradient accumulation (effective batch size 키우기)
📐 production 코드의 7 가지 구성
(1) model definition — pytree 또는 nnx/eqx Module. (2) train_state — params + opt_state + step + key. (3) loss + metrics — value_and_grad 에 has_aux=True. (4) train_step — jit 으로 compile. (5) eval_step — gradient 없이 forward 만 수행; dropout / BatchNorm 이 있는 model 이라면 model.eval() 같은 mode 분리가 여기에 들어간다 (위 MLP 에는 해당 layer 가 없어 생략). (6) schedule — warmup + cosine. (7) optimizer — optax.chain. 어느 model 이든 이 7 개 구성은 같은 모양이다.
이 패턴은 Llama-3 70B 학습 코드든 작은 MLP 든 기본 모양이 동일하다. 차이는 model 의 크기, 데이터 양, 그리고 multi-host 분산 코드가 추가된다는 점 정도다.