JAX 가 등장하기 전에 Autograd 라는 작은 Python 라이브러리가 있었어. 평범한 Python + NumPy 코드를 그대로 쓰면서 자동으로 gradient 를 뽑아내는 도구였지. "수식 그대로 쓰면 미분이 공짜로 따라온다" — 우아했어. 근데 CPU 전용이고, compile 이 안 되고, GPU/TPU 같은 모던 가속기로는 못 갔어.
2018 년에 Google Brain 팀 — Matt Johnson, Roy Frostig, Dougal Maclaurin, Chris Leary 가 모여서 Autograd 의 후계자를 만들었어. 야망이 컸어 — Autograd 의 "Python 함수 그대로 미분" 단순함을 유지하면서, 구글의 XLA compiler 까지 얹어서 GPU/TPU 어디서든 빠르게 돌리겠다는 거. 그게 JAX. 이름은 원래 "Just After eXecution" 줄임말이었는데, 지금은 그냥 JAX 야.
2018 년 말 오픈소스로 풀리자마자 연구자들이 미쳐서 달려들었어. TensorFlow 의 graph 빌드 의식 같은 것도 없고, PyTorch 의 OOP layer 도 없는 — JAX 는 다른 거 제안해. composable function transformations 야. NumPy 처럼 평범한 함수 쓰고, 그 함수를 wrapping 으로 변형해 — compile, differentiate, vectorize, parallelize. 다 함수 한 번 감싸는 걸로.
💡 왜 이게 중요한가
JAX 는 "또 하나의 deep learning framework" 로 만들어진 게 아니야. numerical computing system 인데 마침 ML 도 잘하는 거. 이 차이가 PyTorch/TensorFlow 와 다른 느낌을 주는 이유야 — JAX 는 layer 쌓는 사람이 아니라 수학적 함수로 사고하는 연구자를 위한 도구거든.
지금 JAX 는 Google DeepMind 의 1 차 연구 framework 야. AlphaFold (단백질 구조 예측), Gemini 백엔드 연구, 수많은 논문이 JAX 위에서 돌아갔어. Stanford, MIT, Berkeley, MILA 같은 데서도 과학 컴퓨팅, 물리 시뮬레이션, Bayesian inference, 강화학습 다 JAX 로 해.
JAX 첫 만남이 어떤 모양이냐면:
import jax
import jax.numpy as jnp
# 평범한 함수 — class 도 decorator 도 필요 없음
def f(x):
return jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
# 그냥 호출
print(f(x)) # 14.0
# 이제 gradient — 함수 한 번 감싸기만 하면 끝
grad_f = jax.grad(f)
print(grad_f(x)) # [2. 4. 6.]
# Compile 해서 빠르게
fast_f = jax.jit(f)
print(fast_f(x)) # 14.0 (다음 호출부터 훨씬 빠름)
이게 JAX 의 약속이야. 함수를 쓴다, JAX 가 그 함수를 변형하는 도구를 준다. framework class 상속받을 필요도, 특수 tensor 타입 익힐 필요도 없어. 함수 in, 함수 out — 그게 다야.