C.W.K.
Stream
Lesson 06 of 06 · published

GradientTape로 custom training loop

~13 min · custom-loop, gradient-tape, metrics

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

fit()이 부족할 때

가끔 model.fit()이 충분히 유연하지 않을 때 있어 — GAN은 generator/discriminator 교차 업데이트 필요, custom weighting의 multi-task loss, 연구용 gradient surgery. GradientTape로 custom training loop 짜면 완전한 제어 가능.

train_step은 항상 loop 밖에서 @tf.function으로 데코. 첫 호출에서 trace + 컴파일, 이후 호출은 컴파일된 거 직접 실행. 이게 custom loop이 fit() 속도와 같아지는 이유.

Keras metric은 누적. metric.result()는 마지막 reset_state() 이후 running 평균 반환. epoch 사이 reset 까먹으면 epoch 2 metric에 epoch 1 데이터 섞여서 숫자 이상해 보여. epoch 요약 후 항상 reset.

Code

Custom training loop — canonical pattern·python
import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([keras.Input(shape=(784,)),
                          keras.layers.Dense(256, activation='relu'),
                          keras.layers.Dense(10)])
optimizer = keras.optimizers.Adam(1e-3)
loss_fn   = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = keras.metrics.Mean(name='train_loss')
train_acc  = keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
val_loss   = keras.metrics.Mean(name='val_loss')
val_acc    = keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
        loss += sum(model.losses)         # add regularization losses
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_loss.update_state(loss)
    train_acc.update_state(y, logits)
    return loss

@tf.function
def val_step(x, y):
    logits = model(x, training=False)
    val_loss.update_state(loss_fn(y, logits))
    val_acc.update_state(y, logits)

EPOCHS = 10
for epoch in range(EPOCHS):
    for x_b, y_b in train_ds:
        train_step(x_b, y_b)
    for x_b, y_b in val_ds:
        val_step(x_b, y_b)

    print(f"Epoch {epoch+1}/{EPOCHS} — "
          f"loss: {train_loss.result():.4f}  acc: {train_acc.result():.4f}  "
          f"val_loss: {val_loss.result():.4f}  val_acc: {val_acc.result():.4f}")

    # IMPORTANT: reset between epochs
    train_loss.reset_state(); train_acc.reset_state()
    val_loss.reset_state();   val_acc.reset_state()

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.