C.W.K.
Stream
Lesson 06 of 07 · published

Custom training step

~8 min · subclass

Level 0Keras 도제
0 XP0/97 lessons0/20 achievements
0/120 XP to next level120 XP to go0% complete

fit() 의 default step 이 부족할 때 — GAN, distillation, dynamic loss weight 등 — train_step(self, data) override. 안에서 GradientTape (TF) / autograd (PyTorch) / jax.grad (JAX) 로 직접 forward + backward.

핵심: train_step 은 dict 반환해야 함 (loss/metrics). callback / progress bar / EarlyStopping 등 fit() 의 인프라가 그 dict 으로부터 동작. 직접 loop 짜는 것 보다 deep custom 시 train_step override 가 권장.

Code

class <span class="dc">CustomModel</span>(keras.Model):
    def train_step(self, data):
        x, y = data

        # Forward pass with gradient tracking
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute and apply gradients
        gradients = self.optimizer.compute_gradients(loss, self.trainable_variables)
        self.optimizer.apply(gradients)

        # Update metrics
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {{m.name: m.result() for m in self.metrics}}

External links

Exercise

keras.Model subclass 의 custom train_step 안에 L2 weight regularization 을 gradient 계산에 직접 추가. 한 epoch 학습 후 loss 가 regularizer 포함하는지 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.