실전에서 자주 쓰는 grad variant 3 가지.
1. argnums — 어느 인자에 대한 미분이냐
def f(a, b, c):
return a ** 2 + b * c
# default — argnums=0
df_da = jax.grad(f)(2.0, 3.0, 4.0) # 4.0
# 특정 인자
df_db = jax.grad(f, argnums=1)(2.0, 3.0, 4.0) # 4.0
df_dc = jax.grad(f, argnums=2)(2.0, 3.0, 4.0) # 3.0
# 여러 인자 동시
g_a, g_b = jax.grad(f, argnums=(0, 1))(2.0, 3.0, 4.0)
2. value_and_grad — 값과 gradient 같이
학습 loop 의 정석 — loss 값과 gradient 둘 다 필요. jax.grad 만 부르면 — 한 번 더 forward 돌려야 loss 값 알 수 있어 (낭비). value_and_grad 는 한 forward + backward 로 둘 다.
def loss_fn(params, x, y):
pred = jnp.dot(x, params)
return jnp.mean((pred - y) ** 2)
# 비효율
loss = loss_fn(params, x, y)
grads = jax.grad(loss_fn)(params, x, y) # forward 두 번
# 효율
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
거의 항상 — 학습 코드에선 value_and_grad 쓴다. JAX 의 가장 흔한 무료 speedup.
3. has_aux — 추가 정보 함께 반환
학습 시 loss 외에 metrics (accuracy, perplexity 등) 도 함께 계산하고 싶다. 그런데 grad 는 scalar 출력만 받음. 해결 — has_aux=True 로 두 번째 출력은 "auxiliary, gradient 계산 안 함":
def loss_and_metrics(params, x, y):
pred = jnp.dot(x, params)
loss = jnp.mean((pred - y) ** 2)
metrics = {
"mae": jnp.mean(jnp.abs(pred - y)),
"max_err": jnp.max(jnp.abs(pred - y)),
}
return loss, metrics # tuple
(loss, metrics), grads = jax.value_and_grad(
loss_and_metrics, has_aux=True
)(params, x, y)
print(f"loss: {loss}, mae: {metrics['mae']}")
has_aux=True 로 — 첫 번째 출력에 대해서만 미분, 두 번째는 그대로 통과.
전형적 학습 step 패턴:
@jax.jit
def train_step(state, batch):
'''state: {"params": ..., "step": ...}, batch: (x, y)'''
def loss_and_metrics(params):
x, y = batch
pred = model_apply(params, x)
loss = compute_loss(pred, y)
metrics = {"acc": accuracy(pred, y)}
return loss, metrics
(loss, metrics), grads = jax.value_and_grad(
loss_and_metrics, has_aux=True
)(state["params"])
new_params = jax.tree.map(
lambda p, g: p - 0.001 * g, state["params"], grads
)
new_state = {"params": new_params, "step": state["step"] + 1}
return new_state, loss, metrics
💡 항상 value_and_grad
학습 코드에서 jax.grad 만 따로 쓰는 일은 거의 없어. jax.value_and_grad 가 default 라고 외워. metrics 함께 반환할 때 has_aux=True. 두 패턴이 90% 의 학습 코드를 cover.