패턴: class MyModel(keras.Model) 의 train_step(self, data) override. 안에서 forward + loss + gradient + apply 까지 직접. 끝에 metric dict 반환 (fit() 이 그걸 progress bar / log 에 표시).
backend 별 차이 — TF 는 tf.GradientTape, PyTorch 는 loss.backward() + optimizer.step(), JAX 는 jax.grad. Keras 3 가 backend 추상화는 layer 단에서 끝, training step 의 mechanics 는 backend native.