이론은 그만 — 실제 NumPy 코드를 JAX 로 옮기는 과정을 보자. K-means clustering 한 step.
# BEFORE — 평범한 NumPy
import numpy as np
def kmeans_step_np(X, centroids):
    """Run one K-means iteration (assignment + centroid update) with NumPy.

    Args:
        X: Array of shape (n_samples, n_features).
        centroids: Array of shape (n_clusters, n_features).

    Returns:
        A ``(new_centroids, labels)`` tuple. ``labels[i]`` is the index of the
        centroid nearest to ``X[i]``; ``new_centroids[k]`` is the mean of the
        samples assigned to cluster ``k``.

    Note:
        A cluster that receives no samples keeps the all-zero row it was
        initialised with, i.e. its new centroid is the origin.
    """
    # Broadcast to (n_samples, n_clusters, n_features) and reduce the feature
    # axis to get every sample-to-centroid Euclidean distance at once.
    dists = np.linalg.norm(X[:, None] - centroids[None, :], axis=2)
    labels = np.argmin(dists, axis=1)

    new_centroids = np.zeros_like(centroids)
    for k in range(len(centroids)):
        mask = labels == k
        # Skip empty clusters: the mean of zero rows would be NaN.
        if mask.any():
            new_centroids[k] = X[mask].mean(axis=0)
    return new_centroids, labels
JAX 화 — 단계적으로:
import jax
import jax.numpy as jnp
@jax.jit
def kmeans_step(X, centroids):
    """Run one K-means iteration (assignment + centroid update) under ``jax.jit``.

    The per-cluster Python loop and boolean masking of the NumPy version are
    replaced by a one-hot membership matrix, so the whole step is expressed as
    fixed-shape array operations.

    Args:
        X: Array of shape (n_samples, n_features).
        centroids: Array of shape (n_clusters, n_features).

    Returns:
        A ``(new_centroids, labels)`` tuple. ``labels[i]`` is the index of the
        centroid nearest to ``X[i]``; ``new_centroids[k]`` is the mean of the
        samples assigned to cluster ``k``.

    Note:
        A cluster that receives no samples ends up with an all-zero centroid
        (zero sum divided by the clamped count of 1).
    """
    # (n_samples, n_clusters, n_features) differences, reduced over features.
    dists = jnp.linalg.norm(X[:, None] - centroids[None, :], axis=2)
    labels = jnp.argmin(dists, axis=1)

    # one_hot[i, k] == 1 iff sample i belongs to cluster k.
    one_hot = jax.nn.one_hot(labels, num_classes=centroids.shape[0])
    counts = one_hot.sum(axis=0)
    # (n_clusters, n_samples) @ (n_samples, n_features) -> per-cluster sums.
    sums = one_hot.T @ X

    # Clamp the divisor so empty clusters divide by 1 instead of 0.
    new_centroids = sums / jnp.maximum(counts[:, None], 1)
    return new_centroids, labels
📋 NumPy → JAX 체크리스트
(1) in-place 대입 a[i] = v → a = a.at[i].set(v) (새 배열을 반환하므로 다시 대입해야 한다). (2) Python for 루프 → jnp.where / 행렬 연산 / jax.lax.scan. (3) traced 값에 대한 if → jnp.where / jax.lax.cond. (4) 난수 → jax.random + 명시적 key. (5) jax.jit 으로 wrap. (6) jax.devices() 로 hardware 확인.
import numpy as np
def softmax_numpy(x):
    """Return the softmax of *x* along its last axis, computed stably.

    Args:
        x: Array of scores; the last axis is normalised.

    Returns:
        Array with the same shape as *x* whose entries along the last axis are
        non-negative and sum to 1.
    """
    # Subtracting the per-row maximum leaves the result unchanged but keeps
    # exp() from overflowing on large inputs.
    x_max = np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
def forward_numpy(params, x):
    """Run the forward pass of a two-layer neural network.

    Args:
        params: Sequence ``(w1, b1, w2, b2)`` holding the weights and biases
            of the two layers.
        x: Input batch; its last axis must match the first axis of ``w1``.

    Returns:
        Softmax probabilities produced by ``softmax_numpy`` on the second
        layer's output.
    """
    w1, b1, w2, b2 = params
    # Layer 1
    h = np.dot(x, w1) + b1
    # ReLU — np.maximum returns a new array; nothing is modified in place
    h = np.maximum(h, 0)
    # Layer 2
    logits = np.dot(h, w2) + b2
    return softmax_numpy(logits)
# Initialize the parameters from NumPy's global, seeded RNG
np.random.seed(42)
w1 = np.random.randn(4, 8) * 0.1
b1 = np.zeros(8)
w2 = np.random.randn(8, 3) * 0.1
b2 = np.zeros(3)
params = [w1, b1, w2, b2]

x = np.random.randn(16, 4)  # 16 samples, 4 features
probs = forward_numpy(params, x)
print(probs.shape)  # (16, 3)
print(probs.sum(axis=-1))  # All 1.0

import jax
import jax.numpy as jnp
def softmax_jax(x):
    """Return the softmax of *x* along its last axis, computed stably.

    Same logic as the NumPy version, with ``jnp`` in place of ``np``.

    Args:
        x: Array of scores; the last axis is normalised.

    Returns:
        Array with the same shape as *x* whose entries along the last axis are
        non-negative and sum to 1.
    """
    # Subtracting the per-row maximum leaves the result unchanged but keeps
    # exp() from overflowing on large inputs.
    x_max = jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(x - x_max)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)
def forward_jax(params, x):
    """Run the forward pass of a two-layer neural network as a pure function.

    The result depends only on the arguments and nothing is mutated, so the
    function can be passed to ``jax.jit`` and ``jax.grad``.

    Args:
        params: Sequence ``(w1, b1, w2, b2)`` holding the weights and biases
            of the two layers.
        x: Input batch; its last axis must match the first axis of ``w1``.

    Returns:
        Softmax probabilities produced by ``softmax_jax`` on the second
        layer's output.
    """
    w1, b1, w2, b2 = params
    # Layer 1 followed by ReLU
    h = jnp.dot(x, w1) + b1
    h = jnp.maximum(h, 0)
    # Layer 2
    logits = jnp.dot(h, w2) + b2
    return softmax_jax(logits)
# Initialize with JAX's explicit PRNG: every random draw consumes its own subkey
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 4)  # keys[3] is left unused
w1 = jax.random.normal(keys[0], (4, 8)) * 0.1
b1 = jnp.zeros(8)
w2 = jax.random.normal(keys[1], (8, 3)) * 0.1
b2 = jnp.zeros(3)
params = (w1, b1, w2, b2)  # A tuple of arrays is a pytree (a list would work too)

x = jax.random.normal(keys[2], (16, 4))
probs = forward_jax(params, x)
print(probs.shape)  # (16, 3)

# JIT compile for speed
fast_forward = jax.jit(forward_jax)
probs = fast_forward(params, x)  # First call compiles; subsequent calls are fast
# Get gradients of a loss function
def loss_fn(params, x, targets):
    """Return the mean cross-entropy of the network's predictions.

    Args:
        params: Network parameters, passed through to ``forward_jax``.
        x: Input batch, passed through to ``forward_jax``.
        targets: Per-sample target weights over the classes (one-hot in this
            example); must broadcast against the predicted probabilities.

    Returns:
        Scalar loss: the negative target-weighted log-probability, summed over
        the last axis and averaged over the remaining axes.
    """
    probs = forward_jax(params, x)
    # Cross-entropy loss; the 1e-8 offset keeps log() finite if a probability
    # underflows to exactly 0.
    return -jnp.mean(jnp.sum(targets * jnp.log(probs + 1e-8), axis=-1))
# One-hot targets: class ids cycle 0, 1, 2 across the 16 samples
targets = jax.nn.one_hot(jnp.array([0, 1, 2] * 5 + [0]), 3)

# Differentiate the loss with respect to its first argument (params)
grads = jax.grad(loss_fn, argnums=0)(params, x, targets)
print(type(grads))  # tuple — the gradient mirrors the structure of params
print(grads[0].shape)  # (4, 8) — gradient for w1