C.W.K.
Stream
Lesson 04 of 05 · published

TF와 JAX — Google의 두 framework

~11 min · jax, convergence, stablehlo

Level 0Level 0
0 XP0/78 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

왜 Google은 두 framework 유지하나 (그리고 너한테 의미)

Google은 주요 ML framework 두 개 유지: TensorFlow (2015), JAX (2018). 차이 이해가 2026년 TF ecosystem에서 일하는 누구에게나 필수 context.

차원TensorFlowJAX
ParadigmOO, statefulFunctional, pure function
Model state객체 속성명시적 pytree
핵심 optf.function, GradientTapejit, grad, vmap, pmap
난수전역 RNG명시적 PRNG key 분할
디버그 경험좋음 (eager default)jit 아래 어려움
TPU 성능좋음 (XLA 통해)최고 수준
Production 배포우수 (Serving, TFLite, TF.js)어려움 (jax2tf bridge 필요)

Google 첨단 LLM (Gemini, PaLM 2, Gemma)은 JAX에서 훈련. TensorFlow는 production serving과 mobile 배포에 지배적. 호환 bridge는 jax2tf — TF Serving / TFLite 배포 위해 JAX 함수를 TF로 변환.

더 큰 그림: TF, JAX, PyTorch가 XLA 하드웨어용 공통 중간 표현 StableHLO로 수렴 중. 장기적으로 짜는 framework보다 배포 타깃이 더 중요해 — Keras 3가 이 수렴의 가장 가시적 부분.

Code

jax2tf — JAX function to TF SavedModel·python
from jax.experimental import jax2tf
import jax.numpy as jnp
import tensorflow as tf

def jax_model(x):
    return jnp.sin(jnp.cos(x))

# Convert JAX function to TF (for TF Serving / TFLite)
tf_model = jax2tf.convert(jax_model)
result = tf_model(tf.constant([1.0, 2.0]))

# Export as TF SavedModel from JAX training
tf_func = tf.function(tf_model, autograph=False)
module   = tf.Module()
module.f = tf_func
tf.saved_model.save(module, "/tmp/jax_as_tf_saved_model")

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.