C.W.K.
Stream
Lesson 04 of 06 · published

Model Subclassing — 최대 유연성

~13 min · subclassing, research, training-flag

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

완전한 Python 제어가 필요할 때

keras.Model 서브클래싱은 Python으로 forward pass 제어 권한을 줘. __init__에서 layer 정의, call()에서 계산. 가장 유연한 접근 — dynamic architecture, 조건부 로직, 새로운 state 가진 recurrent cell, static graph에 안 맞는 연구 프로토타입에 써.

training 인자가 중요해. Dropout이랑 BatchNormalization 같은 layer는 training과 inference에서 다르게 작동해. model.fit은 자동으로 training=True 넘겨주지만 custom loop에선 명시적으로 — training step엔 model(x, training=True), eval/inference엔 model(x, training=False).

Subclassing trade-off: model.summary()가 덜 자세함. 저장/불러오기엔 get_config() + from_config() 필요. 배포 예정 model이면 진짜 subclassing이 필요한 게 아닌 한 Functional API가 일반적으로 권장.

Code

ResidualBlock + ResNet, subclassing style·python
import tensorflow as tf
from tensorflow import keras
from keras import layers

class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(filters, 3, padding='same', activation='relu')
        self.conv2 = layers.Conv2D(filters, 3, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.add = layers.Add()
        self.relu = layers.Activation('relu')

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.add([x, inputs])    # skip connection
        return self.relu(x)


class MyResNet(keras.Model):
    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.stem = layers.Conv2D(32, 3, padding='same', activation='relu')
        self.block1 = ResidualBlock(32)
        self.pool = layers.GlobalAveragePooling2D()
        self.dropout = layers.Dropout(0.3)
        self.classifier = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        x = self.stem(inputs)
        x = self.block1(x, training=training)
        x = self.pool(x)
        x = self.dropout(x, training=training)
        return self.classifier(x)

model = MyResNet(num_classes=10)
model.build(input_shape=(None, 32, 32, 32))   # explicit build
model.summary()

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.