JAX 의 모든 기본 도구 — jnp, grad, jit, value_and_grad, pytree — 를 한 번에 써서 실전 trainer 를 작성. 이게 quest 의 가장 중요한 한 lesson.
import jax
import jax.numpy as jnp
from jax import random
import time
# ============ 1. 합성 데이터 ============
key = random.PRNGKey(42)
key, x_key, n_key = random.split(key, 3)
# y = 3 * x_0 - 1.5 * x_1 + 0.5 + noise
true_w = jnp.array([3.0, -1.5])
true_b = 0.5
N = 1000
X = random.uniform(x_key, (N, 2), minval=-2, maxval=2)
noise = 0.1 * random.normal(n_key, (N,))
y = X @ true_w + true_b + noise
# ============ 2. model 정의 ============
def init_params(key, n_features):
'''Xavier-like 초기화'''
return {
"w": random.normal(key, (n_features,)) * jnp.sqrt(1 / n_features),
"b": jnp.zeros(()),
}
def predict(params, x):
return x @ params["w"] + params["b"]
def loss_fn(params, x, y):
pred = predict(params, x)
return jnp.mean((pred - y) ** 2)
# ============ 3. 학습 step ============
@jax.jit
def train_step(params, x, y, lr):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
return new_params, loss
# ============ 4. 학습 loop ============
key, init_key = random.split(key)
params = init_params(init_key, n_features=2)
print(f"초기 params: w={params['w']}, b={params['b']}")
print(f"정답: w={true_w}, b={true_b}\n")
t = time.time()
for step in range(2000):
params, loss = train_step(params, X, y, lr=0.01)
if step % 200 == 0:
print(f"step {step:4d} loss={loss:.4f} "
f"w={params['w']} b={params['b']:.4f}")
print(f"\n학습 시간: {time.time()-t:.2f}s")
# ============ 5. 검증 ============
print(f"\n최종 w: {params['w']}, 정답 {true_w}")
print(f"최종 b: {params['b']:.4f}, 정답 {true_b}")
# 새 데이터로 평가
key, eval_key = random.split(key)
X_test = random.uniform(eval_key, (100, 2), minval=-2, maxval=2)
y_test = X_test @ true_w + true_b
pred_test = predict(params, X_test)
print(f"\ntest MAE: {jnp.mean(jnp.abs(pred_test - y_test)):.4f}")
출력 예:
초기 params: w=[0.55 0.30], b=0.0
정답: w=[ 3. -1.5], b=0.5
step 0 loss=8.5394 w=[ 0.61 -0.10] b=0.0240
step 200 loss=0.4128 w=[ 2.14 -1.00] b=0.4137
step 400 loss=0.0356 w=[ 2.84 -1.42] b=0.4894
step 600 loss=0.0107 w=[ 2.97 -1.49] b=0.4992
step 800 loss=0.0102 w=[ 3.00 -1.50] b=0.4999
...
최종 w: [3.00 -1.50], 정답 [3.0 -1.5]
최종 b: 0.5000, 정답 0.5
test MAE: 0.0021
🌱 모든 후속 quest 의 template
이 30 줄 trainer 의 구조가 — Track 10 에서 Flax NNX 로 NN 만들 때도, Track 11 에서 Optax 적용할 때도, Track 7 에서 multi-GPU 로 확장할 때도 — 그대로 유지돼. params (pytree) → predict (함수) → loss (함수) → grad → update → 반복. 변하는 건 model 의 복잡도뿐이야. 형태는 같아.
이 trainer 에서 추가 실험:
- learning rate 를 0.001, 0.1 로 바꾸고 수렴 속도 비교
- train_step 호출에
.block_until_ready()넣고 정확한 시간 측정 - params 를 dict 대신 NamedTuple 로 바꿔 보기 — pytree 가 임의 구조라는 걸 확인
jax.grad대신jax.value_and_grad(..., has_aux=True)로 metric 도 함께 반환
한 가지 — JAX 에서 학습 코드를 짤 때 가장 먼저 자문하는 질문: "이 함수는 pure 한가? jit 가능한가?" 둘 다 yes 면 — 다음은 거의 자동으로 풀려.