C.W.K.
Stream
Lesson 02 of 05 · published

argnums, value_and_grad, has_aux

~9 min · grad, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

실전에서 자주 쓰는 grad variant 3 가지.

1. argnums — 어느 인자에 대한 미분이냐

def f(a, b, c):
    return a ** 2 + b * c

# default — argnums=0
df_da = jax.grad(f)(2.0, 3.0, 4.0)        # 4.0

# 특정 인자
df_db = jax.grad(f, argnums=1)(2.0, 3.0, 4.0)  # 4.0
df_dc = jax.grad(f, argnums=2)(2.0, 3.0, 4.0)  # 3.0

# 여러 인자 동시
g_a, g_b = jax.grad(f, argnums=(0, 1))(2.0, 3.0, 4.0)

2. value_and_grad — 값과 gradient 같이

학습 loop 의 정석 — loss 값과 gradient 둘 다 필요. jax.grad 만 부르면 — 한 번 더 forward 돌려야 loss 값 알 수 있어 (낭비). value_and_grad 는 한 forward + backward 로 둘 다.

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# 비효율
loss = loss_fn(params, x, y)
grads = jax.grad(loss_fn)(params, x, y)  # forward 두 번

# 효율
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

거의 항상 — 학습 코드에선 value_and_grad 쓴다. JAX 의 가장 흔한 무료 speedup.

3. has_aux — 추가 정보 함께 반환

학습 시 loss 외에 metrics (accuracy, perplexity 등) 도 함께 계산하고 싶다. 그런데 grad 는 scalar 출력만 받음. 해결 — has_aux=True 로 두 번째 출력은 "auxiliary, gradient 계산 안 함":

def loss_and_metrics(params, x, y):
    pred = jnp.dot(x, params)
    loss = jnp.mean((pred - y) ** 2)
    metrics = {
        "mae": jnp.mean(jnp.abs(pred - y)),
        "max_err": jnp.max(jnp.abs(pred - y)),
    }
    return loss, metrics  # tuple

(loss, metrics), grads = jax.value_and_grad(
    loss_and_metrics, has_aux=True
)(params, x, y)

print(f"loss: {loss}, mae: {metrics['mae']}")

has_aux=True 로 — 첫 번째 출력에 대해서만 미분, 두 번째는 그대로 통과.

전형적 학습 step 패턴:

@jax.jit
def train_step(state, batch):
    '''state: {"params": ..., "step": ...}, batch: (x, y)'''

    def loss_and_metrics(params):
        x, y = batch
        pred = model_apply(params, x)
        loss = compute_loss(pred, y)
        metrics = {"acc": accuracy(pred, y)}
        return loss, metrics

    (loss, metrics), grads = jax.value_and_grad(
        loss_and_metrics, has_aux=True
    )(state["params"])

    new_params = jax.tree.map(
        lambda p, g: p - 0.001 * g, state["params"], grads
    )
    new_state = {"params": new_params, "step": state["step"] + 1}
    return new_state, loss, metrics

💡 항상 value_and_grad

학습 코드에서 jax.grad 만 따로 쓰는 일은 거의 없어. jax.value_and_grad 가 default 라고 외워. metrics 함께 반환할 때 has_aux=True. 두 패턴이 90% 의 학습 코드를 cover.

Code

import jax
import jax.numpy as jnp

def f(x, y, z):
    return x ** 2 + y ** 3 + z

# Differentiate with respect to x (argument 0) — default
df_dx = jax.grad(f, argnums=0)
print(df_dx(2.0, 3.0, 4.0))  # 4.0 (d/dx of x^2 = 2x = 4)

# Differentiate with respect to y (argument 1)
df_dy = jax.grad(f, argnums=1)
print(df_dy(2.0, 3.0, 4.0))  # 27.0 (d/dy of y^3 = 3y^2 = 27)

# Differentiate with respect to multiple arguments
df_dxy = jax.grad(f, argnums=(0, 1))
grads = df_dxy(2.0, 3.0, 4.0)
print(grads)  # (4.0, 27.0) — tuple of gradients
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Returns (value, gradient) tuple
val_and_grad_fn = jax.value_and_grad(loss_fn)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.5], [0.3, 0.8]])
y = jnp.array([1.0, 0.5])

loss_val, grads = val_and_grad_fn(params, x, y)
print(f"Loss: {loss_val:.4f}")
print(f"Gradients: {grads}")
import jax
import jax.numpy as jnp

def loss_with_metrics(params, x, y):
    pred = jnp.dot(x, params)
    loss = jnp.mean((pred - y) ** 2)
    # Return loss AND extra info
    metrics = {
        'mse': loss,
        'predictions': pred,
        'max_error': jnp.max(jnp.abs(pred - y))
    }
    return loss, metrics  # (scalar, auxiliary)

# has_aux=True tells grad: the function returns (loss, aux),
# only differentiate with respect to the loss part
grad_fn = jax.grad(loss_with_metrics, has_aux=True)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.5], [0.3, 0.8]])
y = jnp.array([1.0, 0.5])

grads, metrics = grad_fn(params, x, y)
print(f"Gradients: {grads}")
print(f"MSE: {metrics['mse']:.4f}")
print(f"Max error: {metrics['max_error']:.4f}")

# Combined: value_and_grad with has_aux
val_grad_fn = jax.value_and_grad(loss_with_metrics, has_aux=True)
(loss, metrics), grads = val_grad_fn(params, x, y)

External links

Exercise

(loss, metrics_dict) 반환하는 loss_fn. jax.value_and_grad 에 has_aux=True 로 wrap. loss, metrics, gradient shape 출력. gradient 가 metrics dict 무시하는 거 확인 — has_aux 가 unlock 하는 그것.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.