C.W.K.
Stream
Lesson 01 of 06 · published

JAX 의 출생 신고

~8 min · origins, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 가 등장하기 전에 Autograd 라는 작은 Python 라이브러리가 있었어. 평범한 Python + NumPy 코드를 그대로 쓰면서 자동으로 gradient 를 뽑아내는 도구였지. "수식 그대로 쓰면 미분이 공짜로 따라온다" — 우아했어. 근데 CPU 전용이고, compile 이 안 되고, GPU/TPU 같은 모던 가속기로는 못 갔어.

2018 년에 Google Brain 팀 — Matt Johnson, Roy Frostig, Dougal Maclaurin, Chris Leary 가 모여서 Autograd 의 후계자를 만들었어. 야망이 컸어 — Autograd 의 "Python 함수 그대로 미분" 단순함을 유지하면서, 구글의 XLA compiler 까지 얹어서 GPU/TPU 어디서든 빠르게 돌리겠다는 거. 그게 JAX. 이름은 원래 "Just After eXecution" 줄임말이었는데, 지금은 그냥 JAX 야.

2018 년 말 오픈소스로 풀리자마자 연구자들이 미쳐서 달려들었어. TensorFlow 의 graph 빌드 의식 같은 것도 없고, PyTorch 의 OOP layer 도 없는 — JAX 는 다른 거 제안해. composable function transformations 야. NumPy 처럼 평범한 함수 쓰고, 그 함수를 wrapping 으로 변형해 — compile, differentiate, vectorize, parallelize. 다 함수 한 번 감싸는 걸로.

💡 왜 이게 중요한가

JAX 는 "또 하나의 deep learning framework" 로 만들어진 게 아니야. numerical computing system 인데 마침 ML 도 잘하는 거. 이 차이가 PyTorch/TensorFlow 와 다른 느낌을 주는 이유야 — JAX 는 layer 쌓는 사람이 아니라 수학적 함수로 사고하는 연구자를 위한 도구거든.

지금 JAX 는 Google DeepMind 의 1 차 연구 framework 야. AlphaFold (단백질 구조 예측), Gemini 백엔드 연구, 수많은 논문이 JAX 위에서 돌아갔어. Stanford, MIT, Berkeley, MILA 같은 데서도 과학 컴퓨팅, 물리 시뮬레이션, Bayesian inference, 강화학습 다 JAX 로 해.

JAX 첫 만남이 어떤 모양이냐면:

import jax
import jax.numpy as jnp

# 평범한 함수 — class 도 decorator 도 필요 없음
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# 그냥 호출
print(f(x))  # 14.0

# 이제 gradient — 함수 한 번 감싸기만 하면 끝
grad_f = jax.grad(f)
print(grad_f(x))  # [2. 4. 6.]

# Compile 해서 빠르게
fast_f = jax.jit(f)
print(fast_f(x))  # 14.0 (다음 호출부터 훨씬 빠름)

이게 JAX 의 약속이야. 함수를 쓴다, JAX 가 그 함수를 변형하는 도구를 준다. framework class 상속받을 필요도, 특수 tensor 타입 익힐 필요도 없어. 함수 in, 함수 out — 그게 다야.

Code

import jax
import jax.numpy as jnp

# A plain function — no classes, no decorators needed
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# Evaluate it
print(f(x))  # 14.0

# Now get its gradient — just wrap the function
grad_f = jax.grad(f)
print(grad_f(x))  # [2. 4. 6.]

# Compile it for speed
fast_f = jax.jit(f)
print(fast_f(x))  # 14.0 (but faster on subsequent calls)

External links

Exercise

JAX 설치 (CPU 도 OK), 이 lesson 의 3 줄 예제 실행. 그 다음 f(x) = sum(x**2) 를 f(x) = sum(jnp.sin(x)**2) 로 바꾸고 gradient 출력. grad 가 자동으로 작동하는 거 확인. my_first_jax.py 로 저장 — 나중에 다시 와.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.