C.W.K.
Stream
Lesson 07 of 10 · published

Gradient Clipping — 폭발하는 gradient 길들이기

~10 min · clipping, stability, norm

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

backward 가 너무 멀리 가면, clip

일부 training setup 이 huge norm 의 gradient 생성 — RNN 위 긴 sequence, 적절한 normalization 없는 매우 깊은 network, GAN, warmup 없는 transformer training. Huge gradient × 어떤 learning rate = huge update → training 불안정 (NaN loss, 갑작스런 발산).

Gradient clipping 은 optimizer step 전에 gradient 크기 제한. 두 흔한 변형:

  • Global norm 으로 clip — 모든 gradient 의 결합 L2 norm 계산, threshold 초과하면 down-scale. modern training (Transformer, LLM fine-tune) 에서 가장 흔함.
  • 값으로 clip — 각 gradient element 를 [-c, c] 로 clamp. 더 거칠지만, overall large norm 보다 outlier element 일 때 유용.

loop 안 위치

항상 backward()optimizer.step() 사이. backward 전 clip 은 의미 없음 (gradient 아직 없음); step 후 clip 은 너무 늦음.

전형적 max_norm 값은 transformer training 엔 1.0, RNN 엔 가끔 5.0. 올바른 값은 경험적 — training 중 실제 gradient norm log 하면 cap 자주 치는지 보임.

Code

Global norm clip — 표준 recipe·python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 100)
y = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()

# Clip before step — returns the original (pre-clip) total norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"pre-clip norm: {total_norm:.4f}")

optimizer.step()
값으로 clip — 더 거칠고, 가끔 유용·python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
loss = nn.functional.mse_loss(model(torch.randn(4, 10)), torch.randn(4, 2))
loss.backward()

# Clamps each element of every gradient tensor to [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
Gradient norm logging — 진단, cap 없이·python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
loss = nn.functional.mse_loss(model(torch.randn(4, 10)), torch.randn(4, 2))
loss.backward()

# Compute the global norm WITHOUT clipping (pass max_norm=inf)
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float('inf'))
print(f"global grad norm: {norm:.4f}")

# In a training loop, log this every N steps. If it consistently exceeds
# your max_norm, your clip threshold is doing real work — that's a signal
# to investigate (lr too high? warmup needed? mixed-precision underflow?).

External links

Exercise

장난감 sequence task 에 작은 RNN 5 epoch clipping 없이 train. global grad norm 매 batch log. 보통 spike 봐. 이제 clip_grad_norm_(model.parameters(), 1.0) 켜고 다시 — loss curve 가 눈에 띄게 부드러워야 함.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.