PyTorch Lightning, Keras 의 model.fit(...) 같은 high-level wrapper — JAX 에는 없어. 모든 train step 을 손으로 작성. 처음엔 답답한데 — 의도된 거.
JAX 식 학습 루프의 모양
@jax.jit
def train_step(state, batch):
x, y = batch
def loss_fn(params):
pred = model.apply(params, x)
loss = compute_loss(pred, y)
metrics = {"acc": accuracy(pred, y)}
return loss, metrics
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
new_params = optax.apply_updates(state.params, updates)
new_state = state.replace(
params=new_params,
opt_state=new_opt_state,
step=state.step + 1,
)
return new_state, loss, metrics
# 사용자가 직접 loop
for batch in dataloader:
state, loss, metrics = train_step(state, batch)
if state.step % 100 == 0:
print(f"step {state.step}: loss={loss:.4f}")
30 줄. 모든 게 보임. 어디 magic 있는 곳 없음.
왜 이게 좋은가?
- 모든 step 이 가시적: gradient, optimizer state, parameter update — 다 손에 잡힘. Trainer.fit() 에서 안 보이던 부분이 다 노출.
- 커스터마이징 자유: gradient clip, custom 갱신 규칙, 다양한 schedule — 모두 같은 30 줄 안에 추가. wrapper API 의 한계 없음.
- 디버깅 단순: print, assert, breakpoint — 어디든 자유. magic 한 callback 시스템 없음.
- JIT 명확: train_step 함수가 정확히 무엇을 jit 하는지 명시. compile 비용도 가시적.
비교 — PyTorch Lightning
# PyTorch Lightning
class MyModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
pred = self.model(x)
loss = F.cross_entropy(pred, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-3)
trainer = pl.Trainer(max_epochs=10, gpus=4)
trainer.fit(model, dataloader)
편함 — 단, 무엇이 어떻게 돌아가는지 알려면 — Lightning 의 source 를 파야 함. JAX 는 — 그 30 줄이 곧 source.
고수준 wrapper 가 필요할 때
Flax 의 nnx.training, chex, clu 같은 라이브러리가 — 일부 boilerplate 줄여줌. 그러나 — 핵심은 같은 explicit 패턴, 보조 도구만 추가. PyTorch Lightning 같은 통합 wrapper 는 JAX 에 없음 (의도적).
🛠 JAX 학습 코드의 mantra
"Show me the loop." JAX 코드를 받으면 — train_step 함수 한 개 + loop 한 개. 30 줄로 끝. wrapper 가 늘어날수록 — 그 코드가 JAX 답지 않음을 의심해 봐. 학습 코드가 짧아지는 것보다 — 명료한 게 더 가치 있다는 게 JAX 공동체의 가치관.
한 가지 — explicit 학습 코드는 — 처음 익히는 데 학습 곡선이 있어. 몇 번 짜 보면 — 패턴이 눈에 들어와서 — 새 task 마다 빠르게 응용 가능. PyTorch Lightning 의 callback 들 외우는 것보다 — 더 transferable 한 지식이라고 생각.