C.W.K.
Stream
Lesson 06 of 06 · published

Gradient clipping과 stop_gradient

~11 min · clipping, stop-gradient, exploding-gradient, rl

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

Gradient가 말썽일 때

Gradient가 항상 얌전하진 않아. Exploding gradient는 training 중 무한정 자라 (긴 시퀀스의 RNN, Transformer에서 특히), 무한대 크기의 weight update 일으켜. Vanishing gradient는 0으로 줄어들어서 초기 layer 학습을 막아.

Gradient clipping이 폭발의 표준 처방. 두 전략:

  • Global norm clip (clipnorm=1.0): 모든 gradient를 합산 L2 norm이 max_norm 이하가 되도록 scale. RNN, Transformer 권장.
  • Value clip (clipvalue=0.5): 각 gradient 원소를 독립적으로 [-clipvalue, +clipvalue]로 잘라냄.

tf.stop_gradient는 computation graph의 일부에서 gradient 흐름 막아. 감싸진 tensor가 gradient 계산엔 constant처럼 보여. 흔한 용도: DQN의 target network, contrastive learning의 momentum encoder, 값을 계산은 하지만 backprop으로 업데이트되면 안 되는 모든 경우.

Code

Gradient clipping — two strategies·python
import tensorflow as tf

# Strategy 1: clip by global norm — RNN/Transformer recommended
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

# Strategy 2: clip by value
optimizer_val = tf.keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.5)

# Manual clipping in a custom loop
with tf.GradientTape() as tape:
    loss = loss_fn(model(x, training=True), y)

grads = tape.gradient(loss, model.trainable_weights)
grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
stop_gradient — DQN target pattern·python
import tensorflow as tf

# In DQN, the "target" Q-network must NOT receive gradients
# from the current loss — otherwise the target moves with the predictor
# and training diverges.

target_q = tf.stop_gradient(target_model(next_states))
loss = tf.reduce_mean(tf.square(
    current_q - (rewards + gamma * target_q)
))

# stop_gradient is also how you implement gradient blocking in
# contrastive learning, BYOL-style momentum encoders, and any
# place where you want a "snapshot" of a model's output.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.