C.W.K.
Stream
Lesson 04 of 06 · published

GAN 학습 패턴

~8 min · custom-train

Level 0Keras 도제
0 XP0/97 lessons0/20 achievements
0/120 XP to next level120 XP to go0% complete

GAN: generator + discriminator 두 model, 두 optimizer, 번갈아 학습. (1) D 학습: real/fake 분류, D-loss 만 update. (2) G 학습: G 가 만든 fake 를 D 가 real 로 분류하게 학습, G-loss 만 update. fit() 의 단일 loss/optimizer 모델로 안 됨.

train_step override 가 정답: 안에서 두 GradientTape, 두 apply_gradients. metrics dict 에 d_loss / g_loss 둘 다 반환. fit() 의 callback / distribution 인프라 그대로.

Code

class <span class="dc">GAN</span>(keras.Model):
    def __init__(self, generator, discriminator, latent_dim):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

    def train_step(self, real_images):
        batch_size = keras.ops.shape(real_images)[0]
        noise = keras.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        # Train discriminator
        fake_images = self.generator(noise)
        combined = keras.ops.concatenate([real_images, fake_images])
        labels = keras.ops.concatenate([
            keras.ops.ones((batch_size, 1)),
            keras.ops.zeros((batch_size, 1)),
        ])
        d_loss = self._train_discriminator(combined, labels)

        # Train generator
        noise = keras.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = keras.ops.ones((batch_size, 1))
        g_loss = self._train_generator(noise, misleading_labels)

        return {{"d_loss": d_loss, "g_loss": g_loss}}

External links

Exercise

MNIST 용 작은 DCGAN 구현 — small generator (Dense → reshape → ConvTranspose), small discriminator. train_step override 사용. 5 epoch 학습, g_loss 감소 + sample 질적으로 개선 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.