C.W.K.
Stream
Lesson 02 of 06 · published

핵심 명제: 합성 가능한 함수 변환

~10 min · origins, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

대부분의 ML framework 는 객체 중심이야. model object 만들고, 거기에 데이터 흘려보내고, 메서드 호출. JAX 는 정반대로 가. 핵심 아이디어가 composable function transformations — 함수를 인자로 받아서 새로운 변환된 함수를 돌려주는 higher-order 함수.

4 대 변환 — JAX 의 "four pillars":

  • jax.jit — Compile: 함수를 trace 해서 XLA compiler 로 최적화된 머신 코드로. 첫 호출은 compile 때문에 느리지만, 다음 호출부터는 극적으로 빠름.
  • jax.grad — Differentiate: scalar 반환 함수를 받아서 gradient 계산하는 새 함수를 돌려줌.
  • jax.vmap — Vectorize: 한 example 에 작동하는 함수를 batch 에 작동하도록 변환.
  • jax.pmap — Parallelize: 여러 device (GPU/TPU) 에 함수를 동시에 돌림.

핵심 단어가 composable 이야. 자유롭게 stack 가능:

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# 합성: 미분 → compile
fast_grad = jax.jit(jax.grad(loss_fn))

# 또는: per-example gradient → compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))

💡 왜 이게 중요한가

각 변환은 자체로 유용해. 근데 진짜 마법은 합성에서 나와. jit(vmap(grad(f))) — 이걸 PyTorch 에서 하려면 코드 한참 다시 짜야 돼. JAX 에선 한 줄. 변환을 1 등 시민으로 만든 게 — 이게 differential programming 의 패러다임 전환이야.

Code

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)

# Compose transformations: differentiate, then compile
fast_grad = jax.jit(jax.grad(loss_fn))

# Or: vectorize per-example gradients, then compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.5], [0.3, 0.8], [0.9, 0.1]])
y = jnp.array([1.5, 1.0, 0.8])

# Get compiled gradient
print(fast_grad(params, x, y))

# Get per-example gradients (one gradient vector per data point!)
print(per_example_grads(params, x, y))

External links

Exercise

loss_fn 예제 사용. jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))) 합성. batch size 1, 100, 10000 으로 측정. 합성 순서가 중요한 거 — jax.vmap(jax.jit(jax.grad(...))) 도 시도해서 차이 확인.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.