Now combine the four transformations covered so far (jit, grad, vmap, pmap) into a first meaningful program. The goal: a mini trainer that fits a polynomial to noisy data.
import jax
import jax.numpy as jnp
from jax import random

# 1. Generate synthetic data
key = random.PRNGKey(42)
key, x_key, noise_key = random.split(key, 3)
x = random.uniform(x_key, (100,), minval=-5, maxval=5)
true_a, true_b, true_c = 2.0, -1.5, 3.0
y = true_a * x**2 + true_b * x + true_c + 0.5 * random.normal(noise_key, (100,))

# 2. model: just a function
def model(params, x):
    a, b, c = params
    return a * x**2 + b * x + c

# 3. loss: another function
def loss(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# 4. gradient + jit, bundled into one
grad_fn = jax.jit(jax.grad(loss))

# 5. Training loop
params = jnp.array([0.0, 0.0, 0.0])
# x**2 reaches 25, so the loss curvature is large (roughly 250) and
# lr = 0.01 overshoots and diverges; 0.005 is small enough to stay stable.
lr = 0.005
for step in range(2000):
    g = grad_fn(params, x, y)
    params = params - lr * g

print(f"Final params: {params}")
print(f"Ground truth: [{true_a}, {true_b}, {true_c}]")
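Whether plain gradient descent converges on this problem depends on the step size. Because the loss is quadratic in params, its Hessian is constant, and the update is stable only while lr stays below 2 divided by the largest Hessian eigenvalue. The sketch below is an illustrative check rather than part of the original trainer: it rebuilds the same synthetic data and uses jax.hessian to estimate that bound. The name max_stable_lr is made up for this example.

```python
import jax
import jax.numpy as jnp
from jax import random

# Same synthetic setup as the trainer: 100 points, x uniform in [-5, 5].
key = random.PRNGKey(42)
key, x_key, noise_key = random.split(key, 3)
x = random.uniform(x_key, (100,), minval=-5.0, maxval=5.0)
y = 2.0 * x**2 - 1.5 * x + 3.0 + 0.5 * random.normal(noise_key, (100,))

def loss(params, x, y):
    a, b, c = params
    pred = a * x**2 + b * x + c
    return jnp.mean((pred - y) ** 2)

# The loss is quadratic in params, so the Hessian is the same at every point.
hessian = jax.hessian(loss)(jnp.zeros(3), x, y)
eigenvalues = jnp.linalg.eigvalsh(hessian)  # returned in ascending order

# Hypothetical name: the largest step size for which gradient descent is stable.
max_stable_lr = 2.0 / eigenvalues[-1]

print(f"Hessian eigenvalues: {eigenvalues}")
print(f"Gradient descent is stable for lr < {float(max_stable_lr):.5f}")
```

The wide gap between the largest and smallest eigenvalue is also why the constant term c tends to converge much more slowly than a.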
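🌱 Where the quest begins
Getting comfortable with this 30-line trainer is the foundation for everything in the 13 tracks that follow. A neural network is, in the end, just a bigger function. Bigger params, a more complex model, an optimizer library: all of these extend this template rather than introduce a new paradigm.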
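To make the "just a bigger function" point concrete, here is a minimal sketch of the same template with a tiny one-hidden-layer network swapped in for the polynomial. Everything specific in it is an assumption chosen for illustration: the init_params helper, the hidden size of 16, the sin(3x) target, and the learning rate. The only structural change is that params becomes a dict (a pytree), so the update is applied leaf by leaf with jax.tree_util.tree_map.

```python
import jax
import jax.numpy as jnp
from jax import random

def init_params(key, hidden=16):
    """Hypothetical helper: one-hidden-layer MLP weights stored as a dict pytree."""
    k1, k2 = random.split(key)
    return {
        "w1": 0.5 * random.normal(k1, (1, hidden)),
        "b1": jnp.zeros(hidden),
        "w2": 0.5 * random.normal(k2, (hidden, 1)),
        "b2": jnp.zeros(1),
    }

# model: still just a function of (params, x)
def model(params, x):
    h = jnp.tanh(x[:, None] @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]

# loss: unchanged from the polynomial version
def loss(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.05):
    grads = jax.grad(loss)(params, x, y)  # grads has the same dict structure as params
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = random.PRNGKey(0)
key, x_key, init_key = random.split(key, 3)
x = random.uniform(x_key, (100,), minval=-1.0, maxval=1.0)
y = jnp.sin(3.0 * x)  # assumed target, chosen only for illustration

params = init_params(init_key)
initial_loss = loss(params, x, y)
for step in range(500):
    params = update(params, x, y)
final_loss = loss(params, x, y)

print(f"loss: {float(initial_loss):.4f} -> {float(final_loss):.4f}")
```

The data, loss, gradient, and loop sections line up one-to-one with the polynomial trainer; only the model function and the shape of params changed.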
Code
import jax
import jax.numpy as jnp

# 1. Generate synthetic data
key = jax.random.PRNGKey(42)
key, x_key, noise_key = jax.random.split(key, 3)  # one subkey per random draw
X = jax.random.normal(x_key, shape=(100, 3))  # 100 samples, 3 features
true_w = jnp.array([2.0, -1.0, 0.5])
y = X @ true_w + 0.1 * jax.random.normal(noise_key, shape=(100,))

# 2. Define the model and loss as pure functions
def predict(params, x):
    return jnp.dot(x, params)

def loss_fn(params, X, y):
    preds = predict(params, X)
    return jnp.mean((preds - y) ** 2)

# 3. Compile the gradient computation
@jax.jit
def update(params, X, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, X, y)
    return params - lr * grads  # Simple gradient descent

# 4. Train
params = jnp.zeros(3)  # Start from zeros
for step in range(100):
    params = update(params, X, y)
    if step % 20 == 0:
        current_loss = loss_fn(params, X, y)
        print(f"Step {step}, Loss: {current_loss:.4f}")

print(f"Learned params: {params}")
print(f"True params: {true_w}")
# Per-example gradients: gradient of loss for EACH data point
# In PyTorch, this has traditionally required extra tooling. In JAX, it's one line.
def single_loss(params, x, y):
    """Loss for a single example."""
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# vmap over the data dimensions (axis 0 of x and y), not params
per_example_grad_fn = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))
per_example_grads = per_example_grad_fn(params, X, y)
print(per_example_grads.shape)  # (100, 3): one gradient per example
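One quick way to sanity-check a per-example gradient function is to average its output and compare against the ordinary gradient of the mean loss, since the gradient of a mean equals the mean of the gradients. The sketch below does that on freshly generated data so it runs on its own. The params values and the batch_loss name are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x_key, y_key = jax.random.split(key)
X = jax.random.normal(x_key, (100, 3))
y = jax.random.normal(y_key, (100,))
params = jnp.array([0.5, -0.25, 1.0])  # arbitrary point at which to take gradients

def single_loss(params, x, y):
    """Squared error for one example."""
    return (jnp.dot(x, params) - y) ** 2

def batch_loss(params, X, y):
    """Mean squared error over the whole batch."""
    return jnp.mean((X @ params - y) ** 2)

# One gradient row per example: params is shared, x and y are mapped over axis 0.
per_example_grads = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))(params, X, y)
batch_grad = jax.grad(batch_loss)(params, X, y)

# Averaging the per-example gradients should recover the batch gradient.
matches = jnp.allclose(per_example_grads.mean(axis=0), batch_grad, atol=1e-4)
print(per_example_grads.shape, bool(matches))

# A typical use: the gradient norm of each individual example.
per_example_norms = jnp.linalg.norm(per_example_grads, axis=1)
print(per_example_norms.shape)
```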
From-scratch script: fit y = ax² + bx + c to 100 noisy points using jax.grad plus manual gradient descent for 1000 steps. Compare the step time of the jit-compiled version against eager execution. Run to convergence and print the learned coefficients alongside the ground truth.
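For the timing part, one detail is easy to get wrong: JAX dispatches work asynchronously, so a timer that stops before the result is ready measures dispatch, not computation. The sketch below shows one possible way to time a step fairly, with a warm-up call so compilation is excluded and block_until_ready() before the clock stops. The time_steps helper, the noise-free data, the step count, and the learning rate are all assumptions for illustration, and the measured numbers will vary by machine.

```python
import time

import jax
import jax.numpy as jnp

def loss(params, x, y):
    a, b, c = params
    return jnp.mean((a * x**2 + b * x + c - y) ** 2)

def step(params, x, y, lr=0.005):
    # One eager gradient descent step; lr is an assumed value small enough to be stable.
    return params - lr * jax.grad(loss)(params, x, y)

jit_step = jax.jit(step)  # same function, compiled

def time_steps(step_fn, params, x, y, n_steps=200):
    """Hypothetical helper: average wall-clock seconds per step."""
    step_fn(params, x, y).block_until_ready()  # warm-up; triggers compilation for jit
    start = time.perf_counter()
    for _ in range(n_steps):
        params = step_fn(params, x, y)
    params.block_until_ready()  # wait for the async computation to finish
    return (time.perf_counter() - start) / n_steps

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (100,), minval=-5.0, maxval=5.0)
y = 2.0 * x**2 - 1.5 * x + 3.0  # noise omitted here; it does not affect timing
params = jnp.zeros(3)

eager_time = time_steps(step, params, x, y)
jit_time = time_steps(jit_step, params, x, y)
print(f"eager: {eager_time * 1e6:.1f} us/step, jit: {jit_time * 1e6:.1f} us/step")
```

Both versions compute the same update, so any difference in the printed times comes from per-operation dispatch overhead versus a single compiled call.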