C.W.K.
Stream
Lesson 05 of 05 · published

Efficient training with scan + Orbax checkpoints

~11 min · training, jax, tutorial


Two production essentials for training code: fast step iteration (scan) and reliable checkpointing (Orbax).

Repeating training steps with jax.lax.scan

A Python for-loop pays for host ↔ device communication and Python overhead on every step. scan bundles all steps into a single jit-compiled computation that runs on the accelerator in one call.

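import functools

import jax

@functools.partial(jax.jit, static_argnames="n")
def train_n_steps(state, batches, n):
    '''Run n training steps in a single scan.'''
    def body(state, batch):
        # train_step(state, batch) -> (new_state, loss, metrics), defined elsewhere
        new_state, loss, metrics = train_step(state, batch)
        return new_state, (loss, metrics)

    # n is static; scan checks it against the leading axis of batches
    final_state, (losses, metrics_arr) = jax.lax.scan(body, state, batches, length=n)
    return final_state, losses, metrics_arr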

# Usage
N_STEPS = 1000
all_batches = make_batches(N_STEPS)   # shape: (N_STEPS, B, ...)

state, losses, metrics = train_n_steps(state, all_batches, N_STEPS)
print(f"mean loss: {losses.mean():.4f}")

1000 steps run in a single jit call, so host overhead is close to zero. The drawback: every batch must already be on the device, so a large dataset has to be split into chunks.
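To make the comparison concrete, here is a minimal self-contained sketch on a toy least-squares problem. The toy `train_step`, the shapes, and the learning rate are assumptions for illustration, not the lesson's model. It runs the same steps once as a Python loop over a jitted step and once as a single scan, then checks that both reach the same weights.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def train_step(w, batch):
    """One SGD step on a toy least-squares problem (stand-in for a real step)."""
    x, y = batch
    loss, grad = jax.value_and_grad(lambda p: jnp.mean((x @ p - y) ** 2))(w)
    return w - 0.1 * grad, loss


step_jit = jax.jit(train_step)


@jax.jit
def train_scan(w, batches):
    # scan feeds one slice of the leading axis of `batches` to each step.
    return jax.lax.scan(train_step, w, batches)


n_steps, batch_size, dim = 200, 32, 4
kx, kw = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (n_steps, batch_size, dim))
ys = xs @ jax.random.normal(kw, (dim,))          # shape (n_steps, batch_size)
w0 = jnp.zeros(dim)

# Python loop: one dispatch from the host per step.
t0 = time.perf_counter()
w_loop = w0
for i in range(n_steps):
    w_loop, _ = step_jit(w_loop, (xs[i], ys[i]))
w_loop.block_until_ready()
t_loop = time.perf_counter() - t0

# scan: one dispatch for all steps.
t0 = time.perf_counter()
w_scan, losses = train_scan(w0, (xs, ys))
w_scan.block_until_ready()
t_scan = time.perf_counter() - t0

# Both timings include compilation; call each path a second time for steady-state numbers.
print(f"loop: {t_loop:.3f}s  scan: {t_scan:.3f}s")
print("same result:", bool(jnp.allclose(w_loop, w_scan, atol=1e-4)))
```

`block_until_ready()` matters when timing: JAX dispatches asynchronously, so without it the clock stops before the device has finished.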

A realistic pattern is to scan within each chunk and use Python between chunks:

CHUNK = 100   # scan in units of 100 steps

for epoch in range(num_epochs):
    for chunk_idx in range(50):   # 50 chunk = 5000 step
        chunk_batches = get_next_chunk(CHUNK)
        state, losses, metrics = train_n_steps(state, chunk_batches, CHUNK)
        print(f"chunk {chunk_idx}: avg loss = {losses.mean():.4f}")
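The loop above leaves `get_next_chunk` undefined. One possible sketch, under stated assumptions: `batch_stream` is a hypothetical data source that yields `(x, y)` NumPy batches, and the stream is passed in explicitly rather than held as global state. The helper pulls `chunk_size` batches, stacks each leaf along a new leading axis (the axis scan iterates over), and moves the chunk to the device in one transfer.

```python
import jax
import numpy as np


def batch_stream(batch_size, dim, seed=0):
    """Hypothetical data source: yields (x, y) NumPy batches forever."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.normal(size=(batch_size, dim)).astype(np.float32)
        yield x, x.sum(axis=1)


def get_next_chunk(stream, chunk_size):
    """Pull chunk_size batches and stack each leaf along a new leading axis."""
    batches = [next(stream) for _ in range(chunk_size)]
    stacked = jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *batches)
    return jax.device_put(stacked)   # one host-to-device transfer per chunk


stream = batch_stream(batch_size=8, dim=4)
x_chunk, y_chunk = get_next_chunk(stream, chunk_size=100)
print(x_chunk.shape, y_chunk.shape)   # (100, 8, 4) (100, 8)
```

Stacking on the host and transferring once per chunk keeps the number of transfers proportional to the number of chunks rather than the number of steps.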

Orbax checkpoint

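Training a large model takes hours to days. Starting over from scratch after a mid-run crash is not acceptable. Orbax is the standard checkpointing library in the JAX ecosystem.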

pip install orbax-checkpoint

import orbax.checkpoint as ocp
from etils import epath

# Set up the checkpoint manager
ckpt_dir = epath.Path("/tmp/my_model_ckpts")
options = ocp.CheckpointManagerOptions(
    save_interval_steps=500,    # save every 500 steps
    max_to_keep=3,              # keep only the 3 most recent
)
mgr = ocp.CheckpointManager(
    ckpt_dir,
    item_names=("state",),
    options=options,
)

# Save (the step should be a Python int, so convert the array-valued counter)
mgr.save(
    step=int(state.step),
    args=ocp.args.Composite(state=ocp.args.StandardSave(state)),
)
mgr.wait_until_finished()

# Restore (latest); the current state serves as the structure/dtype template
restored = mgr.restore(
    mgr.latest_step(),
    args=ocp.args.Composite(state=ocp.args.StandardRestore(state)),
)
state = restored["state"]
print(f"restored: step {state.step}")

Complete training loop + checkpoint

def train_with_resume(initial_state, resume=True):
    state = initial_state

    if resume and mgr.latest_step() is not None:
        restored = mgr.restore(
            mgr.latest_step(),
            args=ocp.args.Composite(state=ocp.args.StandardRestore(state)),
        )
        state = restored["state"]
        print(f"resumed from step {state.step}")

    for epoch in range(num_epochs):
        for chunk_idx in range(num_chunks):
            chunk_batches = get_next_chunk(CHUNK)
            state, losses, metrics = train_n_steps(state, chunk_batches, CHUNK)

            # Automatic checkpointing: save() only writes when the step
            # matches save_interval_steps
            mgr.save(
                step=int(state.step),
                args=ocp.args.Composite(state=ocp.args.StandardSave(state)),
            )

    mgr.wait_until_finished()
    return state

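When the process dies and is restarted, training resumes automatically from the last checkpoint, including the step counter, optimizer state, and params. The loop counters and the position in the data stream are not part of state, so they have to be derived from state.step on resume.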

multi-host checkpoint

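Large models are trained across multiple hosts, and every host must see the same checkpoint. Orbax handles the coordination:

# Every host creates the same mgr instance and calls save
# Orbax coordinates the write: hosts synchronize, and a primary host (host 0) finalizes the checkpoint
# Shared storage (GCS, S3) is recommended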
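A common point of confusion is what to guard with a host check. As a minimal sketch: the save call itself runs unconditionally on every process, and only side effects such as logging are restricted to process 0. `save_fn` is a hypothetical stand-in for something like `mgr.save`, used here so the snippet runs in a single process without a real checkpoint directory.

```python
import jax


def save_on_all_hosts(save_fn, step, state):
    """Call the collective save on every process; log from the primary only."""
    saved = save_fn(step, state)        # must run on ALL hosts, never inside the if
    if jax.process_index() == 0:        # only one host prints
        print(f"step {step}: save requested on {jax.process_count()} process(es)")
    return saved


# Single-process stand-in for a real manager's save method.
calls = []
result = save_on_all_hosts(lambda step, state: calls.append(step) or True, 500, {"w": 1.0})
```

Wrapping the save itself in `if jax.process_index() == 0:` would leave the other hosts out of the synchronization and can hang a multi-host job.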

🎯 Training stability checklist

(1) scan + jit: removes host overhead. (2) Orbax checkpoint: automatic every N steps. (3) max_to_keep: keeps the disk from filling up. (4) wait_until_finished: call before exit. (5) resume from latest: automatic on process restart. (6) keep the random key in the checkpoint too: preserves training reproducibility. (7) test resume on a toy run: try it once when first writing the code.
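Item (6) can be checked without any checkpoint library. The sketch below is an illustration under assumptions: a toy update with random noise stands in for dropout or data augmentation, and copying to NumPy stands in for writing a checkpoint. It carries the PRNG key in the scan carry and shows that a resumed run matches the uninterrupted run only when the key is restored along with the params.

```python
import jax
import jax.numpy as jnp
import numpy as np


def noisy_step(carry, _):
    w, key = carry
    key, sub = jax.random.split(key)                # advance the key every step
    w = w + 0.1 * jax.random.normal(sub, w.shape)   # stand-in for dropout / data noise
    return (w, key), None


@jax.jit
def run_3_steps(w, key):
    (w, key), _ = jax.lax.scan(noisy_step, (w, key), xs=None, length=3)
    return w, key


w0, key0 = jnp.zeros(4), jax.random.PRNGKey(0)

# Uninterrupted run: 3 + 3 steps.
w_mid, key_mid = run_3_steps(w0, key0)
w_ref, _ = run_3_steps(w_mid, key_mid)

# "Checkpoint" after 3 steps: copy params AND key to host memory, then restore both.
saved = {"w": np.asarray(w_mid), "key": np.asarray(key_mid)}
w_resumed, _ = run_3_steps(jnp.asarray(saved["w"]), jnp.asarray(saved["key"]))

# Resuming with the initial key instead of the saved one gives a different trajectory.
w_fresh, _ = run_3_steps(jnp.asarray(saved["w"]), jax.random.PRNGKey(0))

print("with saved key matches:", bool(jnp.array_equal(w_ref, w_resumed)))
print("with fresh key matches:", bool(jnp.allclose(w_ref, w_fresh)))
```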

This pattern is the kind of skeleton used in training code at Llama / Gemini / GPT scale. Even when the model is 1000x larger, the same seven items and the same checkpoint pattern apply.

Code

import jax
import jax.numpy as jnp
import optax

# Assumed to be defined elsewhere:
#   forward(params, x) -> logits
#   optimizer: an optax.GradientTransformation, e.g. optax.adam(1e-3)

# Instead of a Python loop for multiple gradient steps:
# for _ in range(100):
#     params, loss = train_step(params, batch)

# Use scan for a compiled loop:
@jax.jit
def scan_train(params, opt_state, batches):
    """Run multiple training steps as a single compiled scan."""
    def step_fn(carry, batch):
        params, opt_state = carry
        x, y = batch

        def loss_fn(p):
            logits = forward(p, x)
            return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
                logits, y))

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), loss

    (final_params, final_opt_state), losses = jax.lax.scan(
        step_fn,
        init=(params, opt_state),
        xs=batches,  # stacked batch arrays: (num_steps, batch_size, ...)
    )
    return final_params, final_opt_state, losses

# scan lowers the whole loop to a SINGLE XLA while loop: one dispatch for
# all steps, and far less compile time than a Python for loop inside jit,
# which is unrolled step by step at trace time

import orbax.checkpoint as ocp
from flax import nnx

# Save a checkpoint
checkpointer = ocp.StandardCheckpointer()

# Save parameters (works with any pytree)
save_path = '/tmp/my_model/step_1000'
checkpointer.save(save_path, params)
# In recent Orbax versions this save is asynchronous; wait before reading
checkpointer.wait_until_finished()

# Restore parameters
restored_params = checkpointer.restore(save_path)

# For Flax NNX models, save the state (model: an existing nnx.Module):
state = nnx.state(model)
checkpointer.save('/tmp/my_model/step_2000', state)
checkpointer.wait_until_finished()

# Restore:
restored_state = checkpointer.restore('/tmp/my_model/step_2000')
nnx.update(model, restored_state)

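# CheckpointManager for managing multiple checkpoints
manager = ocp.CheckpointManager(
    '/tmp/checkpoints',
    options=ocp.CheckpointManagerOptions(
        max_to_keep=3,           # keep only 3 most recent
        save_interval_steps=500, # save every 500 steps
    ),
)

# In training loop:
for step in range(10000):
    batch = next(batch_iter)  # batch_iter: your data iterator (assumed)
    params, loss = train_step(params, batch)
    # Skipped on steps that do not match save_interval_steps
    manager.save(step, args=ocp.args.StandardSave(params))
manager.wait_until_finished()  # let pending async saves finish

# Restore latest (params serves as the structure/dtype template):
step = manager.latest_step()
params = manager.restore(step, args=ocp.args.StandardRestore(params))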


Exercise

Replace a Python for-loop trainer with jax.lax.scan over batches. Measure 1000 steps before vs after. Wire in Orbax to checkpoint params every 100 steps. Resume from a checkpoint and verify continuity. This is the dividing line between a research script and a production trainer.
