C.W.K.
Stream
Lesson 06 of 06 · published

실전 예제: NumPy 코드를 JAX 로 옮기기

~11 min · numpy, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

이론은 그만 — 실제 NumPy 코드를 JAX 로 옮기는 과정을 보자. 예제는 K-means clustering 의 한 step 이다: 각 점을 가장 가까운 centroid 에 할당한 뒤, cluster 별 평균으로 centroid 를 갱신한다.

# BEFORE — 평범한 NumPy
import numpy as np

def kmeans_step_np(X, centroids):
    """Run one Lloyd iteration of k-means with plain NumPy.

    Every row of ``X`` is assigned to its nearest centroid (Euclidean
    distance), then each centroid is replaced by the mean of the rows
    assigned to it. A centroid that receives no rows is returned as zeros.

    Returns ``(new_centroids, labels)``.
    """
    # Broadcast rows against centroids; for 2-D inputs X (n, d) and
    # centroids (k, d) this yields an (n, k) distance matrix.
    pairwise = np.linalg.norm(X[:, np.newaxis] - centroids[np.newaxis, :], axis=2)
    labels = pairwise.argmin(axis=1)

    updated = np.zeros_like(centroids)
    for cluster in range(centroids.shape[0]):
        members = X[labels == cluster]
        # Skip empty clusters: their row stays zero (mean of nothing is NaN).
        if members.shape[0] > 0:
            updated[cluster] = members.mean(axis=0)
    return updated, labels

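JAX 화 — 단계적으로: import 를 `jax.numpy` 로 바꾸고, Python for loop · data-dependent `if` · in-place 대입 (`new_centroids[k] = ...`) 을 one-hot 행렬곱으로 대체해 cluster 별 합과 개수를 한 번에 구한 뒤, 함수 전체를 `jax.jit` 으로 감싼다. 빈 cluster 는 NumPy 버전과 마찬가지로 0 이 된다 (개수를 `jnp.maximum(..., 1)` 로 clamp 해서 0 으로 나누지 않음).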

import jax
import jax.numpy as jnp

@jax.jit
def kmeans_step(X, centroids):
    """One jit-compiled k-means step, ported from the NumPy version.

    The per-cluster Python loop and in-place row assignment are replaced by
    a one-hot assignment matrix: per-cluster sums come from one matrix
    product and per-cluster sizes from a column sum. A cluster with no
    members has sum 0 and its size is clamped to 1, so its row ends up 0.

    Returns ``(new_centroids, labels)``.
    """
    # Shapes are static at trace time, so this is a plain Python int under jit.
    num_clusters = centroids.shape[0]
    distances = jnp.linalg.norm(X[:, None] - centroids[None, :], axis=2)
    labels = distances.argmin(axis=1)

    # assignment[i, j] is 1.0 exactly when row i is assigned to cluster j.
    assignment = jax.nn.one_hot(labels, num_classes=num_clusters)
    cluster_sums = jnp.matmul(assignment.T, X)
    cluster_sizes = jnp.maximum(assignment.sum(axis=0), 1)
    return cluster_sums / cluster_sizes[:, None], labels

📋 NumPy → JAX 체크리스트

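(1) in-place 대입 `a[i] = v` → `a = a.at[i].set(v)` (JAX array 는 immutable 이라 새 array 를 반환하므로 결과를 다시 받아야 한다). (2) Python `for` loop → `jnp.where` / 행렬 연산 / `jax.lax.scan`. (3) traced 값에 대한 `if` → `jnp.where` / `jax.lax.cond`. (4) random — 전역 seed 대신 `jax.random` + 명시적 key. (5) `jax.jit` 으로 wrap. (6) `jax.devices()` 로 사용 중인 hardware 확인.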

Code

import numpy as np

def softmax_numpy(x):
    """Numerically stable softmax over the last axis.

    Subtracting the per-row max before ``exp`` keeps large inputs from
    overflowing without changing the result.
    """
    x_max = np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def forward_numpy(params, x):
    """Two-layer neural network: dense -> ReLU -> dense -> softmax.

    ``params`` is a sequence unpacked as ``(w1, b1, w2, b2)``; the result is
    the softmax over the last axis of the second layer's output.
    """
    w1, b1, w2, b2 = params
    # Layer 1
    h = np.dot(x, w1) + b1
    # ReLU, clipped in place (idiomatic NumPy): `h` is a fresh array created
    # just above, so overwriting it never touches the caller's data. JAX
    # arrays are immutable, so a JAX port must use jnp.maximum(h, 0) instead.
    np.maximum(h, 0, out=h)
    # Layer 2
    logits = np.dot(h, w2) + b2
    return softmax_numpy(logits)


# Initialize with NumPy's global, stateful RNG (contrast: JAX uses explicit keys)
np.random.seed(42)
w1 = np.random.randn(4, 8) * 0.1
b1 = np.zeros(8)
w2 = np.random.randn(8, 3) * 0.1
b2 = np.zeros(3)
params = [w1, b1, w2, b2]

x = np.random.randn(16, 4)  # 16 samples, 4 features
probs = forward_numpy(params, x)
print(probs.shape)  # (16, 3)
print(probs.sum(axis=-1))  # All 1.0 (up to float rounding)
import jax
import jax.numpy as jnp

def softmax_jax(x):
    """Numerically stable softmax over the last axis — identical logic to NumPy."""
    x_max = jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(x - x_max)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)


def _logits_jax(params, x):
    """Pre-softmax output of the two-layer network (pure function)."""
    w1, b1, w2, b2 = params
    h = jnp.dot(x, w1) + b1
    # Out-of-place ReLU: JAX arrays are immutable, so no in-place clipping.
    h = jnp.maximum(h, 0)
    return jnp.dot(h, w2) + b2


def forward_jax(params, x):
    """Two-layer neural network — pure function returning class probabilities.

    ``params`` is a sequence unpacked as ``(w1, b1, w2, b2)``.
    """
    return softmax_jax(_logits_jax(params, x))


# Initialize with JAX's explicit PRNG: each random draw gets its own key.
key = jax.random.PRNGKey(42)
k_w1, k_w2, k_x, _ = jax.random.split(key, 4)  # fourth key is spare
w1 = jax.random.normal(k_w1, (4, 8)) * 0.1
b1 = jnp.zeros(8)
w2 = jax.random.normal(k_w2, (8, 3)) * 0.1
b2 = jnp.zeros(3)
# Lists are valid pytrees too; a tuple is used because params are never
# mutated, and jax.grad returns gradients in the same container type.
params = (w1, b1, w2, b2)

x = jax.random.normal(k_x, (16, 4))
probs = forward_jax(params, x)
print(probs.shape)  # (16, 3)
# JIT compile for speed
fast_forward = jax.jit(forward_jax)
probs = fast_forward(params, x)  # First call traces + compiles; later same-shape calls reuse it


# Get gradients of a loss function
def loss_fn(params, x, targets):
    """Mean cross-entropy between one-hot ``targets`` and the network output.

    Log-probabilities come from ``jax.nn.log_softmax`` on the logits rather
    than ``log(probs + eps)``: the epsilon form caps the per-example loss
    near ``-log(eps)`` (about 18.4 for 1e-8) once a probability gets tiny,
    which distorts both the loss and its gradient.
    """
    log_probs = jax.nn.log_softmax(_logits_jax(params, x), axis=-1)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


# One-hot targets
targets = jax.nn.one_hot(jnp.array([0, 1, 2, 0, 1, 2, 0, 1,
                                    2, 0, 1, 2, 0, 1, 2, 0]), 3)

# Gradient with respect to params
grads = jax.grad(loss_fn)(params, x, targets)
print(type(grads))  # tuple — same structure as params!
print(grads[0].shape)  # (4, 8) — gradient for w1

External links

Exercise

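이전에 작성한 작은 NumPy script (~50 줄) 중, loop 안에서 array 를 in-place 로 수정하는 것을 하나 골라라. 이 lesson 의 과정을 따라 JAX 로 port 한다: import 변경, mutation 제거, inner step 에 jit 적용. before/after 실행 시간을 측정하되, JAX 는 비동기로 실행되므로 결과에 `.block_until_ready()` 를 호출한 뒤 시간을 재고, 첫 호출의 compile 시간은 따로 기록하라. 가장 까다로웠던 부분을 3 줄로 메모하라.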

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요? 문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인 · 댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.