C.W.K.
Stream
Lesson 05 of 05 · published

실전: Linear Regression 을 처음부터

~11 min · grad, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 모든 기본 도구 — jnp, grad, jit, value_and_grad, pytree — 를 한 번에 써서 실전 trainer 를 작성. 이게 quest 의 가장 중요한 한 lesson.

import jax
import jax.numpy as jnp
from jax import random
import time

# ============ 1. 합성 데이터 ============
key = random.PRNGKey(42)
key, x_key, n_key = random.split(key, 3)

# y = 3 * x_0 - 1.5 * x_1 + 0.5 + noise
true_w = jnp.array([3.0, -1.5])
true_b = 0.5

N = 1000
X = random.uniform(x_key, (N, 2), minval=-2, maxval=2)
noise = 0.1 * random.normal(n_key, (N,))
y = X @ true_w + true_b + noise

# ============ 2. model 정의 ============
def init_params(key, n_features):
    '''Xavier-like 초기화'''
    return {
        "w": random.normal(key, (n_features,)) * jnp.sqrt(1 / n_features),
        "b": jnp.zeros(()),
    }

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    pred = predict(params, x)
    return jnp.mean((pred - y) ** 2)

# ============ 3. 학습 step ============
@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# ============ 4. 학습 loop ============
key, init_key = random.split(key)
params = init_params(init_key, n_features=2)

print(f"초기 params: w={params['w']}, b={params['b']}")
print(f"정답:       w={true_w}, b={true_b}\n")

t = time.time()
for step in range(2000):
    params, loss = train_step(params, X, y, lr=0.01)
    if step % 200 == 0:
        print(f"step {step:4d}  loss={loss:.4f}  "
              f"w={params['w']}  b={params['b']:.4f}")
print(f"\n학습 시간: {time.time()-t:.2f}s")

# ============ 5. 검증 ============
print(f"\n최종 w: {params['w']}, 정답 {true_w}")
print(f"최종 b: {params['b']:.4f}, 정답 {true_b}")

# 새 데이터로 평가
key, eval_key = random.split(key)
X_test = random.uniform(eval_key, (100, 2), minval=-2, maxval=2)
y_test = X_test @ true_w + true_b
pred_test = predict(params, X_test)
print(f"\ntest MAE: {jnp.mean(jnp.abs(pred_test - y_test)):.4f}")

출력 예:

초기 params: w=[0.55 0.30], b=0.0
정답:       w=[ 3.  -1.5], b=0.5

step    0  loss=8.5394  w=[ 0.61 -0.10]  b=0.0240
step  200  loss=0.4128  w=[ 2.14 -1.00]  b=0.4137
step  400  loss=0.0356  w=[ 2.84 -1.42]  b=0.4894
step  600  loss=0.0107  w=[ 2.97 -1.49]  b=0.4992
step  800  loss=0.0102  w=[ 3.00 -1.50]  b=0.4999
...
최종 w: [3.00 -1.50], 정답 [3.0 -1.5]
최종 b: 0.5000, 정답 0.5
test MAE: 0.0021

🌱 모든 후속 quest 의 template

이 30 줄 trainer 의 구조가 — Track 10 에서 Flax NNX 로 NN 만들 때도, Track 11 에서 Optax 적용할 때도, Track 7 에서 multi-GPU 로 확장할 때도 — 그대로 유지돼. params (pytree) → predict (함수) → loss (함수) → grad → update → 반복. 변하는 건 model 의 복잡도뿐이야. 형태는 같아.

이 trainer 에서 추가 실험:

  • learning rate 를 0.001, 0.1 로 바꾸고 수렴 속도 비교
  • train_step 호출에 .block_until_ready() 넣고 정확한 시간 측정
  • params 를 dict 대신 NamedTuple 로 바꿔 보기 — pytree 가 임의 구조라는 걸 확인
  • jax.grad 대신 jax.value_and_grad(..., has_aux=True) 로 metric 도 함께 반환

한 가지 — JAX 에서 학습 코드를 짤 때 가장 먼저 자문하는 질문: "이 함수는 pure 한가? jit 가능한가?" 둘 다 yes 면 — 다음은 거의 자동으로 풀려.

Code

import jax
import jax.numpy as jnp

# ============================================
# 1. Generate synthetic data
# ============================================
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

n_samples, n_features = 500, 5
true_weights = jax.random.normal(k1, (n_features,))
true_bias = jnp.array(0.7)

X = jax.random.normal(k2, (n_samples, n_features))
noise = 0.1 * jax.random.normal(k3, (n_samples,))
y = X @ true_weights + true_bias + noise

# ============================================
# 2. Define model and loss (pure functions)
# ============================================
def predict(params, x):
    w, b = params
    return x @ w + b

def mse_loss(params, x, y):
    preds = predict(params, x)
    return jnp.mean((preds - y) ** 2)

# ============================================
# 3. Define training step (compiled)
# ============================================
@jax.jit
def train_step(params, x, y, lr):
    loss_val, grads = jax.value_and_grad(mse_loss)(params, x, y)
    w, b = params
    gw, gb = grads
    new_params = (w - lr * gw, b - lr * gb)
    return new_params, loss_val

# ============================================
# 4. Initialize and train
# ============================================
params = (jnp.zeros(n_features), jnp.array(0.0))
learning_rate = 0.1

for epoch in range(200):
    params, loss = train_step(params, X, y, learning_rate)
    if epoch % 40 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss:.6f}")

# ============================================
# 5. Compare results
# ============================================
learned_w, learned_b = params
print(f"\\nTrue weights:    {true_weights}")
print(f"Learned weights: {learned_w}")
print(f"True bias:       {true_bias}")
print(f"Learned bias:    {learned_b:.4f}")
print(f"Weight error:    {jnp.linalg.norm(true_weights - learned_w):.6f}")
import optax

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step_optax(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

External links

Exercise

선형 회귀 from-scratch 구현: y = 3x + 1 + noise 합성, mse loss, jit + grad + 손수 SGD loop 1000 step. 학습된 slope+intercept 출력. JAX 의 가장 작은 end-to-end trainer — 후속 quest 모두의 template.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.