C.W.K.
Stream
Lesson 05 of 08 · published

JAX Backend

~8 min · backend

Level 0Keras 도제
0 XP0/97 lessons0/20 achievements
0/120 XP to next level120 XP to go0% complete

JAX backend 는 *연구·TPU·HPC* 의 home. jit·vmap·pmap 의 함수형 변환이 뼈대고, XLA 컴파일러 통해 GPU/TPU 에 끔찍하게 빠른 실행. Google 내부 large-scale 학습 (Gemini, PaLM 등) 의 default.

Keras 3 + JAX 면 fit() 의 편의를 누리면서 jax.grad 로 직접 gradient 도 계산 가능. functional purity 가 강제돼서 mutable state 에 익숙한 PyTorch 사용자는 처음에 어색할 수 있어.

백엔드 노트:
⚙️ Backend Note

Code

os.environ["KERAS_BACKEND"] = "jax"
import keras

model = keras.Sequential([...])
model.compile(optimizer="adam", loss="mse")

# JAX stateless API for functional purity
variables = model.variables
outputs = model.stateless_call(variables, inputs)

External links

Exercise

KERAS_BACKEND=jax 로 작은 keras.ops.matmul 스니펫을 @jax.jit 데코레이터 함수 안에 넣어. 1000×1000 matmul 의 jit 유/무 시간 비교.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.