C.W.K.
Stream
Lesson 01 of 05 · published

JAX 학습 루프가 명시적인 이유

~8 min · training, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

PyTorch Lightning, Keras 의 model.fit(...) 같은 high-level wrapper — JAX 에는 없어. 모든 train step 을 손으로 작성. 처음엔 답답한데 — 의도된 거.

JAX 식 학습 루프의 모양

@jax.jit
def train_step(state, batch):
    x, y = batch

    def loss_fn(params):
        pred = model.apply(params, x)
        loss = compute_loss(pred, y)
        metrics = {"acc": accuracy(pred, y)}
        return loss, metrics

    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = state.replace(
        params=new_params,
        opt_state=new_opt_state,
        step=state.step + 1,
    )
    return new_state, loss, metrics

# 사용자가 직접 loop
for batch in dataloader:
    state, loss, metrics = train_step(state, batch)
    if state.step % 100 == 0:
        print(f"step {state.step}: loss={loss:.4f}")

30 줄. 모든 게 보임. 어디 magic 있는 곳 없음.

왜 이게 좋은가?

  • 모든 step 이 가시적: gradient, optimizer state, parameter update — 다 손에 잡힘. Trainer.fit() 에서 안 보이던 부분이 다 노출.
  • 커스터마이징 자유: gradient clip, custom 갱신 규칙, 다양한 schedule — 모두 같은 30 줄 안에 추가. wrapper API 의 한계 없음.
  • 디버깅 단순: print, assert, breakpoint — 어디든 자유. magic 한 callback 시스템 없음.
  • JIT 명확: train_step 함수가 정확히 무엇을 jit 하는지 명시. compile 비용도 가시적.

비교 — PyTorch Lightning

# PyTorch Lightning
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        loss = F.cross_entropy(pred, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

trainer = pl.Trainer(max_epochs=10, gpus=4)
trainer.fit(model, dataloader)

편함 — 단, 무엇이 어떻게 돌아가는지 알려면 — Lightning 의 source 를 파야 함. JAX 는 — 그 30 줄이 곧 source.

고수준 wrapper 가 필요할 때

Flax 의 nnx.training, chex, clu 같은 라이브러리가 — 일부 boilerplate 줄여줌. 그러나 — 핵심은 같은 explicit 패턴, 보조 도구만 추가. PyTorch Lightning 같은 통합 wrapper 는 JAX 에 없음 (의도적).

🛠 JAX 학습 코드의 mantra

"Show me the loop." JAX 코드를 받으면 — train_step 함수 한 개 + loop 한 개. 30 줄로 끝. wrapper 가 늘어날수록 — 그 코드가 JAX 답지 않음을 의심해 봐. 학습 코드가 짧아지는 것보다 — 명료한 게 더 가치 있다는 게 JAX 공동체의 가치관.

한 가지 — explicit 학습 코드는 — 처음 익히는 데 학습 곡선이 있어. 몇 번 짜 보면 — 패턴이 눈에 들어와서 — 새 task 마다 빠르게 응용 가능. PyTorch Lightning 의 callback 들 외우는 것보다 — 더 transferable 한 지식이라고 생각.

Code

import jax
import jax.numpy as jnp

# The pattern: forward → loss → grad → update → repeat
def train_step(params, x, y, lr=0.01):
    # 1. Define loss as a function of params
    def loss_fn(params):
        predictions = model_forward(params, x)
        return jnp.mean((predictions - y) ** 2)

    # 2. Compute loss and gradients simultaneously
    loss, grads = jax.value_and_grad(loss_fn)(params)

    # 3. Update parameters (simple SGD)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

    return new_params, loss

# 4. JIT-compile for speed
train_step_jit = jax.jit(train_step)

# 5. Training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        params, loss = train_step_jit(params, *batch)

External links

Exercise

JAX 의 'loop 직접 작성' 과 PyTorch Lightning 의 'configure_optimizer + training_step' 비교. 5-bullet pro/con. JAX 가 포기하라고 하는 것 — 그리고 돌려주는 것 — 한 줄 진술.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.