C.W.K.
Stream
Lesson 04 of 06 · published

model.fit() — 직접 안 써도 되는 training loop

~12 min · fit, history, validation

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

호출 한 번이면 끝나는 training loop

model.fit()이 전체 loop 돌려 — forward, loss, backward, update, metric 추적, batch마다 epoch마다. validation, callback, 진행 표시도 처리.

반환되는 history 객체가 주요 진단 도구. history.history['loss'] vs history.history['val_loss'] 그려서 overfitting 실시간으로 봐. accuracy 등 추적한 metric도 마찬가지.

대용량 데이터엔 tf.data.Dataset 넘겨. NumPy array는 다 메모리에 올려. Dataset은 lazy하게 stream하고 preprocessing 병렬화. tf.data는 다음 track에서 깊이 다뤄 — 지금은 fit()과 매끄럽게 통합된다는 것만 알면 돼.

Code

Full MNIST training·python
import tensorflow as tf
from tensorflow import keras
from keras import layers

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test  = x_test.reshape(-1, 784).astype("float32") / 255.0

model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(256, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.3),
    layers.Dense(10),
])

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

history = model.fit(
    x_train, y_train,
    batch_size=128,
    epochs=20,
    validation_split=0.1,
    shuffle=True,
    verbose=1,
)

print(history.history.keys())
# dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
Plot history — see overfitting·python
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Accuracy')
plt.legend()
plt.show()

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.