C.W.K.
Stream
Lesson 02 of 07 · published

keras.Model subclass

~8 min · subclass

Level 0Keras 도제
0 XP0/97 lessons0/20 achievements
0/120 XP to next level120 XP to go0% complete

패턴: class MyModel(keras.Model); __init__ 에서 layer 들 self.dense1 = Dense(64) 등 정의; call(inputs) 에서 forward 정의. 그게 끝. 그 후 model = MyModel(); model.compile(...); model.fit(...) 같이 사용.

핵심 — __init__ 은 *layer 인스턴스 생성*만, *forward 계산 X*. call 이 forward. 이 분리가 중요해 — layer 는 한 번 만들고 (weight 한 번 생성) call 은 매 forward 마다 호출.

Code

import keras
from keras import layers

class <span class="dc">MyModel</span>(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(128, activation="relu")
        self.dropout = layers.Dropout(0.3)
        self.dense2 = layers.Dense(10, activation="softmax")

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=5)

External links

Exercise

MNIST classifier 를 keras.Model subclass 로 재구현. 첫 call() 전후 model.summary() 비교 — 차이 메모.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.