C.W.K.
Stream
Lesson 02 of 05 · published

Optax: 합성 가능한 Gradient Transformation

~8 min · training, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

Optax 는 — JAX 의 사실상 표준 optimizer 라이브러리. 핵심 아이디어 — optimizer 가 monolithic class 가 아니라 — 작은 transformation 들의 합성.

pip install optax

가장 단순한 사용

import optax
import jax
import jax.numpy as jnp

# AdamW optimizer
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)

# state 초기화
params = {"w": jnp.zeros(10), "b": jnp.zeros(())}
opt_state = optimizer.init(params)

# 학습 step 안에서
@jax.jit
def step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

3 단계 — update 로 grads 를 transform, apply_updates 로 params 에 더하기. PyTorch 의 optimizer.step() 한 줄과 의미 동등.

합성의 힘

Optax 의 매력 — optax.chain 으로 transformation 합성:

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),       # gradient clip
    optax.add_decayed_weights(1e-4),      # weight decay
    optax.scale_by_adam(),                 # Adam moments
    optax.scale_by_schedule(schedule),     # learning rate schedule
    optax.scale(-1.0),                     # 부호 반전 (descent)
)

이 5 줄이 — AdamW + grad clip + scheduled lr 의 표준 trainer 의 optimizer. PyTorch 면 — Adam 클래스 + grad clip 매 step 명시 + schedule callback. JAX/Optax 는 — 한 chain.

주요 transformation

이름역할
scale(c)모든 grad 에 c 곱하기 (보통 -lr)
scale_by_adam()Adam moments (1차, 2차)
scale_by_belief()AdaBelief 의 variance estimate
scale_by_rms()RMSProp 의 second moment
scale_by_schedule(s)schedule 에 따라 lr 변화
add_decayed_weights(wd)L2 weight decay
clip(max)per-element clip
clip_by_global_norm(max)전체 grad norm clip
ema(decay)exponential moving average
masked(t, mask)일부 leaf 에만 transform 적용

실전 예 — bias 와 norm 은 weight decay 안 함

def make_mask(params):
    '''W 면 True, b/gamma/beta 면 False'''
    def is_weight(path, value):
        return value.ndim > 1   # 1D 이상이면 weight matrix
    return jax.tree_util.tree_map_with_path(is_weight, params)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.masked(
        optax.add_decayed_weights(1e-4),
        mask=make_mask(params),
    ),
    optax.scale_by_adam(),
    optax.scale_by_schedule(cosine_schedule),
    optax.scale(-1.0),
)

built-in optimizer

# 자주 쓰는 표준
optax.sgd(lr, momentum=0.9)
optax.adam(lr)
optax.adamw(lr, weight_decay=1e-4)
optax.adamax(lr)
optax.rmsprop(lr)
optax.lamb(lr)
optax.lion(lr, weight_decay=1e-4)   # 2023 의 Google
optax.amsgrad(lr)

각각 — 위 transformation 들의 미리 정의된 chain. 무엇을 chain 하는지 — Optax source 에서 한 번 보면 학습 됨.

🧬 합성으로서 optimizer

Optax 의 정신 — JAX 의 정신 (jit + grad + vmap) 의 그대로. monolithic class 대신 — 작은 transformation 의 합성. 새 optimizer 만든다 = 기존 transformation 을 다른 순서로 chain. AdamW 가 Adam + decay 의 chain 인 것처럼. 이 framework 안에서 — 학술 논문의 새 optimizer 를 한 시간 안에 구현 가능.

Track 11-3 에서 schedule, 11-4 에서 전체 학습 루프. 11-5 에서 scan + checkpoint.

Code

import optax

# Common optimizers — each is a gradient transformation
optimizer = optax.adam(learning_rate=1e-3)
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=0.01)
optimizer = optax.sgd(learning_rate=0.1, momentum=0.9)
optimizer = optax.lion(learning_rate=1e-4)  # newer optimizer

# But the real power is composition
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # gradient clipping
    optax.adam(learning_rate=1e-3),   # Adam optimizer
)

# Or even more custom:
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),      # clip gradients
    optax.scale_by_adam(),                # Adam scaling (no LR)
    optax.add_decayed_weights(0.01),     # L2 regularization
    optax.scale(-1e-3),                  # apply learning rate
)
import jax
import jax.numpy as jnp
import optax

# 1. Create optimizer
optimizer = optax.adamw(learning_rate=1e-3)

# 2. Initialize optimizer state from params
params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros(4)}
opt_state = optimizer.init(params)

# 3. Compute gradients (however you like)
grads = jax.grad(loss_fn)(params, x, y)

# 4. Get updates from optimizer
updates, new_opt_state = optimizer.update(grads, opt_state, params)

# 5. Apply updates to params
new_params = optax.apply_updates(params, updates)

# The full cycle:
# grads → optimizer.update(grads, opt_state, params) → updates, new_opt_state
#                                                      ↓
# new_params = optax.apply_updates(params, updates)
# PyTorch                           # JAX + Optax
# optimizer = Adam(model.params())  # optimizer = optax.adam(1e-3)
#                                   # opt_state = optimizer.init(params)
# optimizer.zero_grad()             # (not needed — grads are values)
# loss.backward()                   # grads = jax.grad(loss_fn)(params)
# optimizer.step()                  # updates, opt_state = optimizer.update(...)
#                                   # params = optax.apply_updates(params, updates)

External links

Exercise

AdamW + clip_by_global_norm + ema 를 optax.chain 으로 합성. 1 optimizer step 실행. state 구조 확인. '4 줄 합성으로 gradient transform = Lego' 철학.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.