fit() 의 default step 이 부족할 때 — GAN, distillation, dynamic loss weight 등 — train_step(self, data) override. 안에서 GradientTape (TF) / autograd (PyTorch) / jax.grad (JAX) 로 직접 forward + backward.
핵심: train_step 은 dict 반환해야 함 (loss/metrics). callback / progress bar / EarlyStopping 등 fit() 의 인프라가 그 dict 으로부터 동작. 직접 loop 짜는 것 보다 deep custom 시 train_step override 가 권장.