C.W.K.
Stream
Lesson 03 of 05 · published

Learning Rate Schedule

~8 min · training, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

큰 모델 학습에서 — learning rate schedule 이 학습 안정성과 최종 성능을 좌우. Optax 가 표준 schedule 을 다 제공.

주요 schedule

import optax
import matplotlib.pyplot as plt

# 1. Constant
sched_const = optax.constant_schedule(1e-3)

# 2. Linear warmup
sched_warmup = optax.linear_schedule(
    init_value=0.0,
    end_value=1e-3,
    transition_steps=1000,   # 1000 step 동안 0 → 1e-3
)

# 3. Cosine decay (warm restart 가능)
sched_cosine = optax.cosine_decay_schedule(
    init_value=1e-3,
    decay_steps=10_000,
    alpha=0.1,   # 최저 = init * 0.1
)

# 4. Warmup + cosine (현대 표준)
sched = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=1000,
    decay_steps=10_000,
    end_value=1e-5,
)

# 5. Exponential decay
sched_exp = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.5,
)

# 6. Polynomial
sched_poly = optax.polynomial_schedule(
    init_value=1e-3,
    end_value=1e-5,
    power=2.0,
    transition_steps=10_000,
)

visualize

steps = jnp.arange(15_000)
lrs = jnp.array([sched(s) for s in steps])

plt.plot(steps, lrs)
plt.xlabel("step"); plt.ylabel("learning rate")
plt.show()

warmup_cosine_decay_schedule 의 모양:

peak ──╮
       │   ╲ (cosine)
       ╱    ╲
0 ────╯      ╲___ end_value
      ↑      ↑
   warmup    decay 끝

학습 코드 통합

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=1000,
    decay_steps=100_000,
    end_value=3e-5,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(),
    optax.scale_by_schedule(schedule),
    optax.scale(-1.0),
)

# 학습 루프 — schedule 이 자동 적용
@jax.jit
def step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, *batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

schedule 의 step 카운터 — opt_state 안에 자동 보존. 사용자가 따로 추적 안 해도 됨.

합성 schedule

# 여러 단계 — 처음엔 warmup, 그 후 cosine, 그 후 constant
sched = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 3e-4, 1000),   # warmup
        optax.cosine_decay_schedule(3e-4, 50_000, alpha=0.1),  # decay
        optax.constant_schedule(3e-5),             # 끝까지 유지
    ],
    boundaries=[1000, 51_000],
)

💡 schedule 디버깅

새 schedule 을 학습에 쓰기 전 — 항상 plot. steps = jnp.arange(N); lrs = schedule(steps) 로 vector 화 호출 가능. peak 값, warmup 길이, decay 모양 — 의도와 일치하는지 한 번 확인. 잘못된 schedule 로 학습 망가지는 게 가장 흔한 trap.

현대 LLM 학습의 표준 — warmup_cosine_decay 또는 warmup + linear decay. peak lr 은 model 사이즈와 batch size 의 함수 (Chinchilla / Llama 식 scaling rule).

Code

import optax
import jax.numpy as jnp

# Cosine decay: starts at init_value, decays to alpha over decay_steps
schedule = optax.cosine_decay_schedule(
    init_value=1e-3,
    decay_steps=10000,
    alpha=0.0,        # minimum learning rate
)

# Check values at different steps
print(f"Step 0: {schedule(0):.6f}")      # 0.001000
print(f"Step 5000: {schedule(5000):.6f}") # 0.000500
print(f"Step 10000: {schedule(10000):.6f}")# 0.000000

# Warmup + cosine decay (very common in practice)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,        # start from 0
    peak_value=1e-3,       # warm up to this
    warmup_steps=1000,     # linear warmup for 1000 steps
    decay_steps=50000,     # total steps including warmup
    end_value=1e-5,        # minimum LR at end
)

# Use schedule with an optimizer
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)

# Or compose with chain
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule),
)
# Exponential decay
schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.96,
)

# Piecewise constant (manual step schedule)
schedule = optax.piecewise_constant_schedule(
    init_value=1e-3,
    boundaries_and_scales={
        5000: 0.1,   # multiply LR by 0.1 at step 5000
        8000: 0.1,   # multiply again at step 8000
    }
)

# Warm restarts (SGDR)
schedule = optax.sgdr_schedule([
    dict(init_value=1e-3, peak_value=1e-3,
         decay_steps=5000, warmup_steps=500),
    dict(init_value=1e-3, peak_value=5e-4,
         decay_steps=5000, warmup_steps=500),
])

External links

Exercise

optax.warmup_cosine_decay_schedule 로 warmup + cosine schedule 구성. 1000 step 의 LR plot. optax.adamw 호출에 wire. 모든 modern training run 의 정석 패턴.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.