Writing the epoch loop by hand instead of calling fit(): for epoch in range(N); for batch in dataset; with GradientTape; forward + loss; tape.gradient; optimizer.apply_gradients. No callbacks, so logging / checkpointing / EarlyStopping are all on you.
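Hand-rolled early stopping, for instance, can be a few lines of plain Python. A minimal sketch, not the Keras callback itself: `EarlyStopper`, `should_stop`, and the validation losses below are hypothetical names and made-up numbers for illustration.

```python
class EarlyStopper:
    """Hypothetical helper: signal a stop once the monitored loss stalls."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest drop that counts as improvement
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best = val_loss
            self.wait = 0
            return False
        # No improvement this epoch
        self.wait += 1
        return self.wait >= self.patience


# Made-up validation losses, one per epoch, just to exercise the helper
val_losses = [0.9, 0.7, 0.71, 0.72, 0.5]

stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break
```

In a real loop, `val_loss` would come from your own validation pass at the end of each epoch.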
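The price of this freedom is missing infrastructure. No progress bar (add tqdm yourself), checkpoints every N batches are saved by hand, and validation has to be called by hand too. Fine for short experiments; for production, overriding train_step is sturdier, since fit() still drives the loop and callbacks, progress bar, and validation keep working.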
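The train_step route might look roughly like this on the TensorFlow backend. A sketch along the lines of the Keras customization guide, under assumptions: `CustomModel` is an illustrative name, and the tiny random dataset exists only to show that fit() still runs the loop.

```python
import numpy as np
import tensorflow as tf
import keras


class CustomModel(keras.Model):
    """Illustrative subclass: only the per-batch step is custom."""

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Uses the loss configured in compile()
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update the built-in loss tracker and any compiled metrics
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


# Tiny made-up dataset: 32 samples, 4 features, 3 classes
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 3, size=(32,))

inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(3, activation="softmax")(inputs)
model = CustomModel(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
)

# fit() still provides the epoch loop, callbacks, and progress bar
history = model.fit(x, y, batch_size=8, epochs=2, verbose=0)
```

Callbacks such as EarlyStopping or ModelCheckpoint can then be passed to fit() as usual instead of being reimplemented.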
⚙️ Backend Note
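# Manual training loop, TensorFlow backend.
# Gradient computation is backend-specific; this version uses tf.GradientTape.
# Assumes `model` and `train_dataset` are defined elsewhere.
import keras
import tensorflow as tf

optimizer = keras.optimizers.Adam(1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

for epoch in range(10):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        # Forward pass with gradient tracking
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            loss = loss_fn(y_batch, y_pred)
        # Compute gradients and update weights
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Reports the loss of the last batch in the epoch
    print(f"Epoch {epoch}, Loss: {float(loss):.4f}")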
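The missing infrastructure (per-batch logging, a checkpoint every N batches) can be kept out of the backend-specific step by passing that step in as a callable. A rough sketch under assumptions: `run_epoch`, `step_fn`, and `checkpoint_fn` are hypothetical names, and the lambda below is a stand-in for a real forward/backward/update step.

```python
def run_epoch(batches, step_fn, checkpoint_fn=None,
              checkpoint_every=100, log_every=50):
    """Run one epoch and return the mean training loss.

    step_fn(batch) performs one training step and returns a scalar loss.
    checkpoint_fn(step), if given, is called every `checkpoint_every` steps.
    """
    total_loss, num_steps = 0.0, 0
    for step, batch in enumerate(batches, start=1):
        loss = float(step_fn(batch))
        total_loss += loss
        num_steps += 1
        if log_every and step % log_every == 0:
            print(f"step {step}: running mean loss {total_loss / num_steps:.4f}")
        if checkpoint_fn is not None and step % checkpoint_every == 0:
            checkpoint_fn(step)
    return total_loss / num_steps if num_steps else float("nan")


# Stand-in data and step function, only to exercise the helper
saved_steps = []
mean_loss = run_epoch(
    batches=range(1, 7),
    step_fn=lambda batch: 1.0 / batch,   # pretend loss for batch k is 1/k
    checkpoint_fn=saved_steps.append,    # record when a checkpoint would be saved
    checkpoint_every=2,
    log_every=3,
)
```

In the loop above, `step_fn` would wrap the GradientTape block, and `checkpoint_fn` would do the actual saving, for example via `model.save_weights`.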