C.W.K.
Stream
Lesson 09 of 10 · published

torch.func — Per-Sample Gradient 와 Vmap

~12 min · torch.func, vmap, grad, jacobian

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

Functional autograd, batched transform

torch.func (옛 standalone functorch 의 in-tree 대체) 가 JAX-style function transform 줘: grad, vmap, jacrev, hessian. single example 에 동작하는 코드를 짜면 batch 에 자동 vectorize.

의외로 자주 등장하는 두 실용 use case:

  1. Per-sample gradient. 표준 .backward() 는 batch 의 합쳐진 loss 의 gradient 줘 — 모든 parameter 에 gradient 하나. 일부 research (differential privacy, influence function, GradSAM) 는 모든 individual sample 의 loss 의 gradient 필요. torch.func.vmap(grad(...)) 가 Python loop 없이 그걸 줘.
  2. Higher-order gradient. Hessian-vector product, second-order optimizer, meta-learning — 모두 gradient 의 gradient 필요. torch.func.grad(grad(f)) 가 깔끔히 조합.

정신적 전환

표준 PyTorch: tensor 가 implicit graph state 들고, .backward() 호출. torch.func: input 과 parameter 의 pure function, transform 이 새 pure function 생산. 작지만 실재하는 mindset 변화 — JAX 에 더 가까움, 정상 training 에 살짝 덜 ergonomic, 위 case 엔 훨씬 더 강력.

Code

function transform 으로서의 grad·python
import torch
from torch.func import grad

def f(x):
    return torch.sin(x) * x

# grad(f) is a NEW function: x → df/dx
df_dx = grad(f)

x = torch.tensor(1.0)
print(df_dx(x))            # cos(1)*1 + sin(1) ≈ 0.541 + 0.841 = 1.381

# Higher order — grad(grad(f))
d2f_dx2 = grad(grad(f))
print(d2f_dx2(x))          # -sin(1)*1 + 2*cos(1) ≈ ...
vmap — single-example function 자동-batch·python
import torch
from torch.func import vmap, grad

def loss_per_example(w, x, y):
    pred = (w * x).sum()
    return (pred - y) ** 2

w = torch.randn(4)
batch_x = torch.randn(8, 4)         # batch of 8
batch_y = torch.randn(8)

# Single-sample gradient: grad w.r.t. w
single_grad = grad(loss_per_example)

# Vectorize over the batch dim of x and y (in_dims=(None, 0, 0))
per_sample_grads = vmap(single_grad, in_dims=(None, 0, 0))(w, batch_x, batch_y)
print(per_sample_grads.shape)  # torch.Size([8, 4]) — one gradient per sample
nn.Module 의 per-sample gradient — modern way·python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, grad

model = nn.Linear(4, 2)
params = dict(model.named_parameters())

def compute_loss(params, x, y):
    pred = functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)
    return ((pred - y) ** 2).mean()

batch_x = torch.randn(16, 4)
batch_y = torch.randn(16, 2)

# Per-sample gradient w.r.t. params
per_sample_grad = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
    params, batch_x, batch_y
)
print({k: v.shape for k, v in per_sample_grad.items()})
# {'weight': torch.Size([16, 2, 4]), 'bias': torch.Size([16, 2])}

External links

Exercise

2-layer MLP 잡기. 32 example 의 batch 에 torch.func.vmap(grad(...)) 로 per-sample gradient 계산. Python loop 로 한 sample 씩 gradient 계산해서 일치 확인. 둘 다 시간 — vmap 이 한 자릿수 더 빨라야 함.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.