C.W.K.
Stream
Lesson 06 of 06 · published

첫 JAX 프로그램: 한 번에 다 짜기

~10 min · origins, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

지금까지 본 4 가지 변환 (jit, grad, vmap, pmap) 을 합쳐서 의미 있는 첫 프로그램을 짜. 목표: noisy 한 데이터에 polynomial 을 fit 하는 미니 trainer.

import jax
import jax.numpy as jnp
from jax import random

# 1. 합성 데이터 만들기
key = random.PRNGKey(42)
key, x_key, noise_key = random.split(key, 3)

x = random.uniform(x_key, (100,), minval=-5, maxval=5)
true_a, true_b, true_c = 2.0, -1.5, 3.0
y = true_a * x**2 + true_b * x + true_c + 0.5 * random.normal(noise_key, (100,))

# 2. model — 그냥 함수
def model(params, x):
    a, b, c = params
    return a * x**2 + b * x + c

# 3. loss — 또 함수
def loss(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# 4. gradient + jit — 하나로 묶음
grad_fn = jax.jit(jax.grad(loss))

# 5. 학습 loop
params = jnp.array([0.0, 0.0, 0.0])
lr = 0.01
for step in range(2000):
    g = grad_fn(params, x, y)
    params = params - lr * g

print(f"최종 params: {params}")
print(f"정답:       [{true_a}, {true_b}, {true_c}]")

🌱 quest 의 출발점

지금 이 30 줄짜리 trainer 를 익히는 게 — 뒤의 13 개 track 모든 것의 기반이야. neural network 도 결국 더 큰 함수일 뿐. 더 큰 params, 더 복잡한 model, optimizer 라이브러리 — 다 이 template 의 확장이지 새 패러다임이 아니야.

Code

import jax
import jax.numpy as jnp

# 1. Generate synthetic data
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, shape=(100, 3))  # 100 samples, 3 features

true_w = jnp.array([2.0, -1.0, 0.5])
y = X @ true_w + 0.1 * jax.random.normal(key, shape=(100,))

# 2. Define the model and loss as pure functions
def predict(params, x):
    return jnp.dot(x, params)

def loss_fn(params, X, y):
    preds = predict(params, X)
    return jnp.mean((preds - y) ** 2)

# 3. Compile the gradient computation
@jax.jit
def update(params, X, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, X, y)
    return params - lr * grads  # Simple gradient descent

# 4. Train
params = jnp.zeros(3)  # Start from zeros
for step in range(100):
    params = update(params, X, y)
    if step % 20 == 0:
        current_loss = loss_fn(params, X, y)
        print(f"Step {step}, Loss: {current_loss:.4f}")

print(f"Learned params: {params}")
print(f"True params:    {true_w}")
# Per-example gradients: gradient of loss for EACH data point
# In PyTorch, this requires special tricks. In JAX, it's one line.

def single_loss(params, x, y):
    """Loss for a single example."""
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# vmap over the data dimensions (axis 0 of x and y), not params
per_example_grad_fn = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))

per_example_grads = per_example_grad_fn(params, X, y)
print(per_example_grads.shape)  # (100, 3) — one gradient per example

External links

Exercise

from-scratch script — y = ax² + bx + c 를 100 noisy point 에 fit, jax.grad + 수동 gradient descent 1000 step. jit-compiled vs eager 의 step time 비교. 수렴까지 — 학습된 coefficient 와 ground truth 같이 출력.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.