때로 — autodiff 가 만든 gradient 가 마음에 안 들 수 있어. 너무 느리거나, 수치적으로 불안정하거나, custom 의미를 부여하고 싶거나. jax.custom_vjp / jax.custom_jvp 로 직접 정의.
예: numerically stable softmax-cross-entropy
import jax
import jax.numpy as jnp
# 단순 구현 — log(softmax(x)) 가 underflow 위험
def naive_loss(logits, target):
p = jax.nn.softmax(logits)
return -jnp.sum(target * jnp.log(p + 1e-10))
# JAX 의 jax.nn.log_softmax — 안정
def stable_loss(logits, target):
log_p = jax.nn.log_softmax(logits)
return -jnp.sum(target * log_p)
여기까진 standard 함수로 충분. custom gradient 가 필요한 경우 — 예를 들어 quantization 의 straight-through estimator (STE):
@jax.custom_vjp
def ste_round(x):
'''forward 에선 round, backward 에선 identity'''
return jnp.round(x)
def ste_round_fwd(x):
return jnp.round(x), x # residual = x (저장)
def ste_round_bwd(x_residual, grad_out):
return (grad_out,) # grad 가 그대로 통과
ste_round.defvjp(ste_round_fwd, ste_round_bwd)
# 사용
def loss(x):
return jnp.sum(ste_round(x) ** 2)
grad = jax.grad(loss)(jnp.array([1.5, 2.7, 3.1]))
print(grad) # gradient 가 흐름 (round 가 미분 0 인데도)
round 는 — 거의 모든 점에서 미분이 0. 그러나 STE 는 — 마치 identity 인 것처럼 gradient 를 통과시켜 학습 가능하게.
custom_jvp — forward mode
@jax.custom_jvp
def square(x):
return x ** 2
@square.defjvp
def square_jvp(primals, tangents):
x, = primals
dx, = tangents
primal_out = x ** 2
tangent_out = 2 * x * dx # 우리가 정의한 forward derivative
return primal_out, tangent_out
대부분의 use case — custom_vjp (reverse mode). custom_jvp 는 forward mode 가 더 효율적인 경우.
실전 예: 외부 solver 의 gradient
@jax.custom_vjp
def ode_solve(initial_state, t_final, params):
'''black-box solver 호출'''
return some_ode_solver(initial_state, t_final, params)
def ode_solve_fwd(initial_state, t_final, params):
final_state = some_ode_solver(initial_state, t_final, params)
return final_state, (initial_state, t_final, params, final_state)
def ode_solve_bwd(residuals, grad_final):
'''adjoint method 로 gradient 계산'''
initial_state, t_final, params, final_state = residuals
# 별도 ODE 를 풀어 gradient 얻음
grad_initial, grad_params = solve_adjoint(...)
return (grad_initial, None, grad_params)
ode_solve.defvjp(ode_solve_fwd, ode_solve_bwd)
scientific computing 에서 — solver 자체를 differentiable 하게 만드는 표준 패턴. Diffrax (Track 13) 가 이런 걸 깔끔히 wrapping.
🔬 언제 custom gradient?
(1) numerical stability — log-sum-exp, log(softmax) 같은 수치 안정 표현. (2) 미분 불가능 — round, argmax, sample. STE 패턴. (3) 외부 solver — ODE, fixed-point iteration, optimizer-as-layer. adjoint method. (4) 새로운 정규화 — weight constraint, manifold projection. 일반적으로 — 99% 의 코드는 standard autodiff 면 충분. custom 은 — 신호 강할 때만.
위 패턴은 — JAX 의 가장 power-user 영역. 그러나 한 번 익히면 — research 코드의 표현력 폭발적으로 증가.