research 코드가 잘 돌아도 — production 인프라가 JAX 가 아니면? 표준 답: jax2tf 로 TensorFlow SavedModel 출력.
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf
# JAX model
@jax.jit
def my_model(params, x):
return jax.nn.softmax(x @ params["W"] + params["b"])
# JAX → TF function
tf_fn = jax2tf.convert(my_model, with_gradient=False)
# wrapper Module
class MyTfModule(tf.Module):
def __init__(self, params):
self.params = tf.nest.map_structure(tf.Variable, params)
@tf.function(input_signature=[tf.TensorSpec((None, 784), tf.float32)])
def serve(self, x):
return tf_fn(self.params, x)
# 저장
module = MyTfModule(params)
tf.saved_model.save(module, "my_saved_model/")
저장 후 — 다른 곳에서 TF 로 로드:
loaded = tf.saved_model.load("my_saved_model/")
y = loaded.serve(x_input) # JAX 가 없는 환경에서도 동작
TF Serving, TF Lite (mobile), TFJS (브라우저), Vertex AI 등 — 다 호환.
onnx export
JAX → ONNX 직접 변환은 — 여전히 (2026) 미숙. 전형적 경로 — JAX → TF → ONNX:
import tf2onnx
# 위에서 저장한 SavedModel 을 ONNX 로
spec = (tf.TensorSpec((None, 784), tf.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(
module,
input_signature=spec,
output_path="model.onnx",
)
ONNX 의 호환성이 점점 좋아지고 있어 — TensorRT, OpenVINO, CoreML 등 다양한 inference engine 으로.
StableHLO — 새로운 IR
2024 년부터 — Google 이 StableHLO 를 cross-framework IR 로 밀고 있음. JAX, TF, PyTorch 모두 export 가능. inference 인프라 (XLA 기반) 가 직접 받음:
from jax.experimental import export
# JAX 함수 → StableHLO
exported = export.export(my_model)(params, jax.ShapeDtypeStruct((None, 784), jnp.float32))
serialized = exported.serialize()
# 다른 곳에서 로드
loaded = export.deserialize(serialized)
y = loaded.call(params, x_input)
장기적으론 — StableHLO 가 ONNX 의 자리를 대체할 가능성. 단기 — TF SavedModel 이 가장 검증된 경로.
ONNX 우회 — JAX 직접 deploy
인프라가 JAX 가능하면 — 그냥 JAX 로 deploy:
# FastAPI server
from fastapi import FastAPI
import jax
import pickle
app = FastAPI()
# load params at startup
with open("params.pkl", "rb") as f:
params = pickle.load(f)
# pre-compile
dummy_input = jnp.zeros((1, 784))
jit_predict = jax.jit(my_model)
_ = jit_predict(params, dummy_input) # warm-up
@app.post("/predict")
def predict(data: dict):
x = jnp.array(data["input"])
y = jit_predict(params, x)
return {"output": y.tolist()}
JAX 가 production 에서 잘 도는 게 — Google 내부에선 일반적. 외부에선 TF / ONNX 변환이 더 흔한 게 인프라 친화도 차이.
🚀 deploy 전략 결정
(1) 인프라가 TF/Vertex AI 위주 — jax2tf → SavedModel. (2) 모바일 / edge — jax2tf → TF Lite. (3) 브라우저 — jax2tf → TFJS. (4) JAX 직접 가능 (FastAPI, custom server) — 그냥 JAX. (5) ONNX 가 필요 — JAX → TF → ONNX. 변환마다 약간의 지원 손실 가능 — 중요한 op 가 모두 지원되는지 미리 검증.
한 가지 — 변환된 model 이 — 원본과 bit-identical 한지 항상 검증. 수치 차이가 미묘하게 다를 수 있어. unit test 로 random input 비교.