Optax 는 — JAX 의 사실상 표준 optimizer 라이브러리. 핵심 아이디어 — optimizer 가 monolithic class 가 아니라 — 작은 transformation 들의 합성.
pip install optax
가장 단순한 사용
import optax
import jax
import jax.numpy as jnp
# AdamW optimizer
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
# state 초기화
params = {"w": jnp.zeros(10), "b": jnp.zeros(())}
opt_state = optimizer.init(params)
# 학습 step 안에서
@jax.jit
def step(params, opt_state, x, y):
grads = jax.grad(loss_fn)(params, x, y)
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state
3 단계 — update 로 grads 를 transform, apply_updates 로 params 에 더하기. PyTorch 의 optimizer.step() 한 줄과 의미 동등.
합성의 힘
Optax 의 매력 — optax.chain 으로 transformation 합성:
optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # gradient clip
optax.add_decayed_weights(1e-4), # weight decay
optax.scale_by_adam(), # Adam moments
optax.scale_by_schedule(schedule), # learning rate schedule
optax.scale(-1.0), # 부호 반전 (descent)
)
이 5 줄이 — AdamW + grad clip + scheduled lr 의 표준 trainer 의 optimizer. PyTorch 면 — Adam 클래스 + grad clip 매 step 명시 + schedule callback. JAX/Optax 는 — 한 chain.
주요 transformation
| 이름 | 역할 |
|---|---|
scale(c) | 모든 grad 에 c 곱하기 (보통 -lr) |
scale_by_adam() | Adam moments (1차, 2차) |
scale_by_belief() | AdaBelief 의 variance estimate |
scale_by_rms() | RMSProp 의 second moment |
scale_by_schedule(s) | schedule 에 따라 lr 변화 |
add_decayed_weights(wd) | L2 weight decay |
clip(max) | per-element clip |
clip_by_global_norm(max) | 전체 grad norm clip |
ema(decay) | exponential moving average |
masked(t, mask) | 일부 leaf 에만 transform 적용 |
실전 예 — bias 와 norm 은 weight decay 안 함
def make_mask(params):
'''W 면 True, b/gamma/beta 면 False'''
def is_weight(path, value):
return value.ndim > 1 # 1D 이상이면 weight matrix
return jax.tree_util.tree_map_with_path(is_weight, params)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.masked(
optax.add_decayed_weights(1e-4),
mask=make_mask(params),
),
optax.scale_by_adam(),
optax.scale_by_schedule(cosine_schedule),
optax.scale(-1.0),
)
built-in optimizer
# 자주 쓰는 표준
optax.sgd(lr, momentum=0.9)
optax.adam(lr)
optax.adamw(lr, weight_decay=1e-4)
optax.adamax(lr)
optax.rmsprop(lr)
optax.lamb(lr)
optax.lion(lr, weight_decay=1e-4) # 2023 의 Google
optax.amsgrad(lr)
각각 — 위 transformation 들의 미리 정의된 chain. 무엇을 chain 하는지 — Optax source 에서 한 번 보면 학습 됨.
🧬 합성으로서 optimizer
Optax 의 정신 — JAX 의 정신 (jit + grad + vmap) 의 그대로. monolithic class 대신 — 작은 transformation 의 합성. 새 optimizer 만든다 = 기존 transformation 을 다른 순서로 chain. AdamW 가 Adam + decay 의 chain 인 것처럼. 이 framework 안에서 — 학술 논문의 새 optimizer 를 한 시간 안에 구현 가능.
Track 11-3 에서 schedule, 11-4 에서 전체 학습 루프. 11-5 에서 scan + checkpoint.