C.W.K.
Stream
Lesson 01 of 05 · published

Custom Gradient Rules: custom_vjp / custom_jvp

~8 min · advanced, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

때로 — autodiff 가 만든 gradient 가 마음에 안 들 수 있어. 너무 느리거나, 수치적으로 불안정하거나, custom 의미를 부여하고 싶거나. jax.custom_vjp / jax.custom_jvp 로 직접 정의.

예: numerically stable softmax-cross-entropy

import jax
import jax.numpy as jnp

# 단순 구현 — log(softmax(x)) 가 underflow 위험
def naive_loss(logits, target):
    p = jax.nn.softmax(logits)
    return -jnp.sum(target * jnp.log(p + 1e-10))

# JAX 의 jax.nn.log_softmax — 안정
def stable_loss(logits, target):
    log_p = jax.nn.log_softmax(logits)
    return -jnp.sum(target * log_p)

여기까진 standard 함수로 충분. custom gradient 가 필요한 경우 — 예를 들어 quantization 의 straight-through estimator (STE):

@jax.custom_vjp
def ste_round(x):
    '''forward 에선 round, backward 에선 identity'''
    return jnp.round(x)

def ste_round_fwd(x):
    return jnp.round(x), x   # residual = x (저장)

def ste_round_bwd(x_residual, grad_out):
    return (grad_out,)   # grad 가 그대로 통과

ste_round.defvjp(ste_round_fwd, ste_round_bwd)

# 사용
def loss(x):
    return jnp.sum(ste_round(x) ** 2)

grad = jax.grad(loss)(jnp.array([1.5, 2.7, 3.1]))
print(grad)   # gradient 가 흐름 (round 가 미분 0 인데도)

round 는 — 거의 모든 점에서 미분이 0. 그러나 STE 는 — 마치 identity 인 것처럼 gradient 를 통과시켜 학습 가능하게.

custom_jvp — forward mode

@jax.custom_jvp
def square(x):
    return x ** 2

@square.defjvp
def square_jvp(primals, tangents):
    x, = primals
    dx, = tangents
    primal_out = x ** 2
    tangent_out = 2 * x * dx   # 우리가 정의한 forward derivative
    return primal_out, tangent_out

대부분의 use case — custom_vjp (reverse mode). custom_jvp 는 forward mode 가 더 효율적인 경우.

실전 예: 외부 solver 의 gradient

@jax.custom_vjp
def ode_solve(initial_state, t_final, params):
    '''black-box solver 호출'''
    return some_ode_solver(initial_state, t_final, params)

def ode_solve_fwd(initial_state, t_final, params):
    final_state = some_ode_solver(initial_state, t_final, params)
    return final_state, (initial_state, t_final, params, final_state)

def ode_solve_bwd(residuals, grad_final):
    '''adjoint method 로 gradient 계산'''
    initial_state, t_final, params, final_state = residuals
    # 별도 ODE 를 풀어 gradient 얻음
    grad_initial, grad_params = solve_adjoint(...)
    return (grad_initial, None, grad_params)

ode_solve.defvjp(ode_solve_fwd, ode_solve_bwd)

scientific computing 에서 — solver 자체를 differentiable 하게 만드는 표준 패턴. Diffrax (Track 13) 가 이런 걸 깔끔히 wrapping.

🔬 언제 custom gradient?

(1) numerical stability — log-sum-exp, log(softmax) 같은 수치 안정 표현. (2) 미분 불가능 — round, argmax, sample. STE 패턴. (3) 외부 solver — ODE, fixed-point iteration, optimizer-as-layer. adjoint method. (4) 새로운 정규화 — weight constraint, manifold projection. 일반적으로 — 99% 의 코드는 standard autodiff 면 충분. custom 은 — 신호 강할 때만.

위 패턴은 — JAX 의 가장 power-user 영역. 그러나 한 번 익히면 — research 코드의 표현력 폭발적으로 증가.

Code

import jax
import jax.numpy as jnp

@jax.custom_vjp
def safe_log(x):
    """Log that's numerically stable in the forward pass."""
    return jnp.log(jnp.maximum(x, 1e-7))

def safe_log_fwd(x):
    """Forward pass: returns (output, residuals for backward)."""
    result = safe_log(x)
    return result, x  # save x for the backward pass

def safe_log_bwd(x, g):
    """Backward pass: g is the incoming gradient, x is the saved residual."""
    # Clip gradient to prevent explosion near zero
    return (g / jnp.maximum(x, 1e-7),)

safe_log.defvjp(safe_log_fwd, safe_log_bwd)

# Now gradients are clipped automatically
x = jnp.array([1.0, 0.01, 0.0001, 0.0])
grads = jax.grad(lambda x: jnp.sum(safe_log(x)))(x)
print(grads)  # [1.0, 100.0, 10000.0, 10000000.0] — capped, not inf
@jax.custom_vjp
def straight_through_round(x):
    """Round in forward pass, identity gradient in backward."""
    return jnp.round(x)

def ste_fwd(x):
    return straight_through_round(x), x

def ste_bwd(x, g):
    return (g,)  # gradient passes through unchanged

straight_through_round.defvjp(ste_fwd, ste_bwd)

# Forward: rounds values. Backward: gradient flows as if no rounding happened
x = jnp.array([1.3, 2.7, -0.5])
print(straight_through_round(x))  # [1., 3., -0.]
print(jax.grad(lambda x: jnp.sum(straight_through_round(x)))(x))  # [1., 1., 1.]
@jax.custom_jvp
def my_relu(x):
    return jnp.maximum(x, 0.0)

@my_relu.defjvp
def my_relu_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = my_relu(x)
    # Custom: use sigmoid as a smooth approximation of the step function
    tangent_out = x_dot * jax.nn.sigmoid(10.0 * x)
    return primal_out, tangent_out

External links

Exercise

jax.custom_vjp 로 sin(x) 의 gradient override 해서 0 반환 (의도적으로 우스움). jax.grad 가 일치하는 거 검증. 그 다음 real use case: custom_vjp 로 numerical-stable softmax-cross-entropy. 드물지만 필요할 때 결정적 기법.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.