C.W.K.
Stream
Lesson 01 of 05 · published

Production 용 JAX Model 내보내기

~8 min · ecosystem, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

research 코드가 잘 돌아도 — production 인프라가 JAX 가 아니면? 표준 답: jax2tf 로 TensorFlow SavedModel 출력.

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

# JAX model
@jax.jit
def my_model(params, x):
    return jax.nn.softmax(x @ params["W"] + params["b"])

# JAX → TF function
tf_fn = jax2tf.convert(my_model, with_gradient=False)

# wrapper Module
class MyTfModule(tf.Module):
    def __init__(self, params):
        self.params = tf.nest.map_structure(tf.Variable, params)

    @tf.function(input_signature=[tf.TensorSpec((None, 784), tf.float32)])
    def serve(self, x):
        return tf_fn(self.params, x)

# 저장
module = MyTfModule(params)
tf.saved_model.save(module, "my_saved_model/")

저장 후 — 다른 곳에서 TF 로 로드:

loaded = tf.saved_model.load("my_saved_model/")
y = loaded.serve(x_input)   # JAX 가 없는 환경에서도 동작

TF Serving, TF Lite (mobile), TFJS (브라우저), Vertex AI 등 — 다 호환.

onnx export

JAX → ONNX 직접 변환은 — 여전히 (2026) 미숙. 전형적 경로 — JAX → TF → ONNX:

import tf2onnx

# 위에서 저장한 SavedModel 을 ONNX 로
spec = (tf.TensorSpec((None, 784), tf.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(
    module,
    input_signature=spec,
    output_path="model.onnx",
)

ONNX 의 호환성이 점점 좋아지고 있어 — TensorRT, OpenVINO, CoreML 등 다양한 inference engine 으로.

StableHLO — 새로운 IR

2024 년부터 — Google 이 StableHLO 를 cross-framework IR 로 밀고 있음. JAX, TF, PyTorch 모두 export 가능. inference 인프라 (XLA 기반) 가 직접 받음:

from jax.experimental import export

# JAX 함수 → StableHLO
exported = export.export(my_model)(params, jax.ShapeDtypeStruct((None, 784), jnp.float32))
serialized = exported.serialize()

# 다른 곳에서 로드
loaded = export.deserialize(serialized)
y = loaded.call(params, x_input)

장기적으론 — StableHLO 가 ONNX 의 자리를 대체할 가능성. 단기 — TF SavedModel 이 가장 검증된 경로.

ONNX 우회 — JAX 직접 deploy

인프라가 JAX 가능하면 — 그냥 JAX 로 deploy:

# FastAPI server
from fastapi import FastAPI
import jax
import pickle

app = FastAPI()

# load params at startup
with open("params.pkl", "rb") as f:
    params = pickle.load(f)

# pre-compile
dummy_input = jnp.zeros((1, 784))
jit_predict = jax.jit(my_model)
_ = jit_predict(params, dummy_input)   # warm-up

@app.post("/predict")
def predict(data: dict):
    x = jnp.array(data["input"])
    y = jit_predict(params, x)
    return {"output": y.tolist()}

JAX 가 production 에서 잘 도는 게 — Google 내부에선 일반적. 외부에선 TF / ONNX 변환이 더 흔한 게 인프라 친화도 차이.

🚀 deploy 전략 결정

(1) 인프라가 TF/Vertex AI 위주 — jax2tf → SavedModel. (2) 모바일 / edge — jax2tf → TF Lite. (3) 브라우저 — jax2tf → TFJS. (4) JAX 직접 가능 (FastAPI, custom server) — 그냥 JAX. (5) ONNX 가 필요 — JAX → TF → ONNX. 변환마다 약간의 지원 손실 가능 — 중요한 op 가 모두 지원되는지 미리 검증.

한 가지 — 변환된 model 이 — 원본과 bit-identical 한지 항상 검증. 수치 차이가 미묘하게 다를 수 있어. unit test 로 random input 비교.

Code

import jax
from jax import export
import jax.numpy as jnp
import numpy as np

# 1. Create a JIT-transformed function
@jax.jit
def predict(params, x):
    h = jax.nn.relu(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']

# 2. Define input shapes for export
params_shapes = {
    'w1': jax.ShapeDtypeStruct((784, 256), jnp.float32),
    'b1': jax.ShapeDtypeStruct((256,), jnp.float32),
    'w2': jax.ShapeDtypeStruct((256, 10), jnp.float32),
    'b2': jax.ShapeDtypeStruct((10,), jnp.float32),
}
x_shape = jax.ShapeDtypeStruct((1, 784), jnp.float32)

# 3. Export to StableHLO
exported = export.export(predict)(params_shapes, x_shape)

# 4. Get the StableHLO module (MLIR text)
stablehlo_module = exported.mlir_module()

# 5. Serialize for later use
serialized = export.export(predict)(params_shapes, x_shape).serialize()
# Can be saved to disk and loaded in a different process/language
# Export with dynamic batch dimension
scope = export.SymbolicScope()
dynamic_x = jax.ShapeDtypeStruct(
    export.symbolic_shape("batch, 784", scope=scope),
    jnp.float32,
)

exported_dynamic = export.export(predict)(params_shapes, dynamic_x)
# Now the exported model accepts any batch size
# You can pack StableHLO into a TensorFlow SavedModel
# for serving with TensorFlow Serving
from jax.experimental.jax2tf import convert as jax2tf_convert

# Note: the recommended path is now:
# 1. Export to StableHLO via jax.export
# 2. Load StableHLO into TF SavedModel if needed for TF Serving
# 3. Or use StableHLO directly with XLA-compatible runtimes

External links

Exercise

학습된 model. jax2tf 로 TF SavedModel export. plain TensorFlow 에서 reload, 동일 prediction 검증. export story 가 — JAX 가 우리 회사에서 research-only 로 끝날지 결정.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.