큰 모델 학습에서 — learning rate schedule 이 학습 안정성과 최종 성능을 좌우. Optax 가 표준 schedule 을 다 제공.
주요 schedule
import optax
import matplotlib.pyplot as plt
# 1. Constant
sched_const = optax.constant_schedule(1e-3)
# 2. Linear warmup
sched_warmup = optax.linear_schedule(
init_value=0.0,
end_value=1e-3,
transition_steps=1000, # 1000 step 동안 0 → 1e-3
)
# 3. Cosine decay (warm restart 가능)
sched_cosine = optax.cosine_decay_schedule(
init_value=1e-3,
decay_steps=10_000,
alpha=0.1, # 최저 = init * 0.1
)
# 4. Warmup + cosine (현대 표준)
sched = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=1e-3,
warmup_steps=1000,
decay_steps=10_000,
end_value=1e-5,
)
# 5. Exponential decay
sched_exp = optax.exponential_decay(
init_value=1e-3,
transition_steps=1000,
decay_rate=0.5,
)
# 6. Polynomial
sched_poly = optax.polynomial_schedule(
init_value=1e-3,
end_value=1e-5,
power=2.0,
transition_steps=10_000,
)
visualize
steps = jnp.arange(15_000)
lrs = jnp.array([sched(s) for s in steps])
plt.plot(steps, lrs)
plt.xlabel("step"); plt.ylabel("learning rate")
plt.show()
warmup_cosine_decay_schedule 의 모양:
peak ──╮
│ ╲ (cosine)
╱ ╲
0 ────╯ ╲___ end_value
↑ ↑
warmup decay 끝
학습 코드 통합
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=3e-4,
warmup_steps=1000,
decay_steps=100_000,
end_value=3e-5,
)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.scale_by_adam(),
optax.scale_by_schedule(schedule),
optax.scale(-1.0),
)
# 학습 루프 — schedule 이 자동 적용
@jax.jit
def step(params, opt_state, batch):
grads = jax.grad(loss_fn)(params, *batch)
updates, opt_state = optimizer.update(grads, opt_state, params)
return optax.apply_updates(params, updates), opt_state
schedule 의 step 카운터 — opt_state 안에 자동 보존. 사용자가 따로 추적 안 해도 됨.
합성 schedule
# 여러 단계 — 처음엔 warmup, 그 후 cosine, 그 후 constant
sched = optax.join_schedules(
schedules=[
optax.linear_schedule(0.0, 3e-4, 1000), # warmup
optax.cosine_decay_schedule(3e-4, 50_000, alpha=0.1), # decay
optax.constant_schedule(3e-5), # 끝까지 유지
],
boundaries=[1000, 51_000],
)
💡 schedule 디버깅
새 schedule 을 학습에 쓰기 전 — 항상 plot. steps = jnp.arange(N); lrs = schedule(steps) 로 vector 화 호출 가능. peak 값, warmup 길이, decay 모양 — 의도와 일치하는지 한 번 확인. 잘못된 schedule 로 학습 망가지는 게 가장 흔한 trap.
현대 LLM 학습의 표준 — warmup_cosine_decay 또는 warmup + linear decay. peak lr 은 model 사이즈와 batch size 의 함수 (Chinchilla / Llama 식 scaling rule).