C.W.K.
Stream
Lesson 02 of 06 · published

train_step() override

~8 min · custom-train

Level 0Keras 도제
0 XP0/97 lessons0/20 achievements
0/120 XP to next level120 XP to go0% complete

패턴: class MyModel(keras.Model)train_step(self, data) override. 안에서 forward + loss + gradient + apply 까지 직접. 끝에 metric dict 반환 (fit() 이 그걸 progress bar / log 에 표시).

backend 별 차이 — TF 는 tf.GradientTape, PyTorch 는 loss.backward() + optimizer.step(), JAX 는 jax.grad. Keras 3 가 backend 추상화는 layer 단에서 끝, training step 의 mechanics 는 backend native.

Code

class <span class="dc">CustomModel</span>(keras.Model):
    def train_step(self, data):
        x, y = data

        # Forward pass + loss
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute and apply gradients
        grads = self.optimizer.compute_gradients(
            loss, self.trainable_variables
        )
        self.optimizer.apply(grads)

        # Update and return metrics
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        return {{m.name: m.result() for m in self.metrics}}

# Still use fit()!
model = CustomModel(...)
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=10)

External links

Exercise

keras.Model subclass 의 train_step 에 L2 regularization 항목을 gradient 에 직접 추가. MNIST 한 epoch 학습, regularizer 가 loss 에 나타나는지 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.