학습 코드의 두 가지 production 필수 — 빠른 step 반복 (scan) + 안정적 checkpoint (Orbax).
jax.lax.scan 으로 학습 step 반복
Python for-loop — 매 step 마다 host ↔ device 통신 + Python overhead. scan — 모든 step 을 single jit IR 로 묶어 가속기에서 한 번에.
@jax.jit
def train_n_steps(state, batches, n):
'''n 개 step 을 scan 으로 한 번에'''
def body(state, batch):
new_state, loss, metrics = train_step(state, batch)
return new_state, (loss, metrics)
final_state, (losses, metrics_arr) = jax.lax.scan(body, state, batches)
return final_state, losses, metrics_arr
# 사용
N_STEPS = 1000
all_batches = make_batches(N_STEPS) # shape: (N_STEPS, B, ...)
state, losses, metrics = train_n_steps(state, all_batches, N_STEPS)
print(f"평균 loss: {losses.mean():.4f}")
1000 step 을 한 번의 jit 호출로 — host overhead 거의 0. 단점: 모든 batch 가 미리 device 에 있어야 함 (큰 dataset 이면 chunk).
현실적 패턴 — chunk 단위로 scan, chunk 사이는 Python:
CHUNK = 100 # 100 step 단위로 scan
for epoch in range(num_epochs):
for chunk_idx in range(50): # 50 chunk = 5000 step
chunk_batches = get_next_chunk(CHUNK)
state, losses, metrics = train_n_steps(state, chunk_batches, CHUNK)
print(f"chunk {chunk_idx}: avg loss = {losses.mean():.4f}")
Orbax checkpoint
큰 모델 학습 — 몇 시간 ~ 며칠. 중간에 죽으면 처음부터 → 절대 안 됨. Orbax 가 표준.
pip install orbax-checkpoint
import orbax.checkpoint as ocp
from etils import epath
# checkpoint manager 설정
ckpt_dir = epath.Path("/tmp/my_model_ckpts")
options = ocp.CheckpointManagerOptions(
save_interval_steps=500, # 500 step 마다 저장
max_to_keep=3, # 최근 3 개만 보관
)
mgr = ocp.CheckpointManager(
ckpt_dir,
item_names=("state",),
options=options,
)
# 저장
mgr.save(
step=state.step,
args=ocp.args.Composite(state=ocp.args.StandardSave(state)),
)
mgr.wait_until_finished()
# 복원 (latest)
restored = mgr.restore(
mgr.latest_step(),
args=ocp.args.Composite(state=ocp.args.StandardRestore(state)),
)
state = restored["state"]
print(f"복원: step {state.step}")
완전한 학습 루프 + checkpoint
def train_with_resume(initial_state, resume=True):
state = initial_state
if resume and mgr.latest_step() is not None:
restored = mgr.restore(
mgr.latest_step(),
args=ocp.args.Composite(state=ocp.args.StandardRestore(state)),
)
state = restored["state"]
print(f"resumed from step {state.step}")
for epoch in range(num_epochs):
for chunk_idx in range(num_chunks):
chunk_batches = get_next_chunk(CHUNK)
state, losses, metrics = train_n_steps(state, chunk_batches, CHUNK)
# checkpoint 자동 저장
mgr.save(
step=state.step,
args=ocp.args.Composite(state=ocp.args.StandardSave(state)),
)
mgr.wait_until_finished()
return state
process 가 죽고 다시 시작 — 마지막 checkpoint 부터 자동 복원. step 카운터, optimizer state, params 모두.
multi-host checkpoint
큰 model — multi-host 학습. 모든 host 가 같은 checkpoint 를 봐야 함. Orbax 가 자동 처리:
# 모든 host 가 같은 mgr instance 만들고, save 호출
# Orbax 가 host 0 만 disk write, 나머지는 wait
# distributed file system (GCS, S3) 권장
🎯 학습 안정성 체크리스트
(1) scan + jit — host overhead 제거. (2) Orbax checkpoint — 매 N step 자동. (3) max_to_keep — disk 가득 안 차게. (4) wait_until_finished — exit 전 호출. (5) resume from latest — process 재시작 자동. (6) checkpoint 안에 random key 도 — 학습 재현성 유지. (7) test resume on toy run — 첫 코드 짤 때 한 번 실험.
이 패턴이 — Llama / Gemini / GPT 학습 코드의 표준 골격. 모델 사이즈가 1000 배 커도 — 같은 7 가지 구성에 같은 checkpoint pattern.