C.W.K.
Stream
Lesson 05 of 05 · published

실전: Impure 코드 리팩터링

~11 min · purity, jax, tutorial

Level 0 · 호기심
0 XP · 0/73 lessons · 0/17 achievements
0/100 XP to next level · 100 XP to go · 0% complete

실제로 마주칠 흔한 impure 패턴 → 깨끗한 JAX 형태로 옮기는 연습이야.

케이스 1: 학습 step counter

# BEFORE
class Trainer:
    """Trainer that keeps its step counter and loss log on the instance.

    Impure by design: every train_step call mutates ``self``.
    """

    def __init__(self):
        self.step = 0
        self.loss_history = []

    def train_step(self, params, x, y):
        """Run one SGD update (learning rate 0.01) and return the new params.

        Side effects: increments ``self.step`` and appends the current loss
        to ``self.loss_history``.
        """
        self.step += 1  # mutation of self
        self.loss_history.append(compute_loss(params, x, y))  # mutation of self
        gradient = jax.grad(compute_loss)(params, x, y)
        return params - 0.01 * gradient

# Problem: step / loss_history are mutations of self, so the method cannot be wrapped in jit.
# AFTER — state is passed in and handed back explicitly
@jax.jit
def train_step(state, params, x, y):
    """Pure SGD step: (state, params, batch) -> (new_state, new_params).

    The learning rate is fixed at 0.01. ``state`` is a dict with "step" and
    "last_loss"; a fresh dict is returned and the input is left untouched.
    """
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)

    def sgd(p, g):
        return p - 0.01 * g

    next_params = jax.tree.map(sgd, params, grads)
    next_state = dict(step=state["step"] + 1, last_loss=loss)
    return next_state, next_params

# Usage
loss_history = []  # plain Python list; it lives outside jit, on the caller side
state = {"step": 0, "last_loss": 0.0}
for x, y in batches:
    state, params = train_step(state, params, x, y)
    # Record the loss once per step (inside the loop, not after it).
    # float() pulls the value back to the host, which waits for the step to finish.
    loss_history.append(float(state["last_loss"]))

케이스 2: dropout — random state 의 명시화

# BEFORE
def dropout(x, p=0.5):
    """Inverted dropout driven by NumPy's global RNG (impure: hidden random state)."""
    keep = np.random.rand(*x.shape) > p   # ❌ global random
    scale = 1 - p
    return x * keep / scale
# AFTER — the PRNG key is an explicit argument
def dropout(x, key, p=0.5):
    """Inverted dropout with an explicit PRNG key (pure).

    Args:
        x: input array.
        key: JAX PRNG key; pass a fresh subkey on every call.
        p: drop probability. Each element is kept with probability ``1 - p``
            and kept elements are scaled by ``1 / (1 - p)``.

    Returns:
        Array with dropped elements set exactly to 0. When ``p == 1`` every
        element is dropped and the result is all zeros, not NaN.
    """
    keep_prob = 1.0 - p
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    # Guard the divisor so p == 1 cannot produce 0/0. The inner where also keeps
    # the discarded branch finite, so gradients do not pick up NaN.
    safe_keep = jnp.where(keep_prob > 0, keep_prob, 1.0)
    # select instead of multiply: dropped entries are exactly 0 (never -0.0 or nan*0).
    return jnp.where(mask, x / safe_keep, 0)

# Usage
key = jax.random.PRNGKey(0)
# Split before use: `subkey` is consumed by this dropout call, `key` is kept for later draws.
key, subkey = jax.random.split(key)
y = dropout(x, subkey, p=0.5)

케이스 3: lookup cache

# BEFORE — memoization through a module-level dict (hidden global state)
embedding_cache = {}


def get_embedding(token_id):
    """Return the embedding for ``token_id``, computing and caching it on first use.

    Impure: the first call for a given id writes into the global ``embedding_cache``.
    """
    if token_id in embedding_cache:
        return embedding_cache[token_id]
    embedding_cache[token_id] = expensive_compute(token_id)
    return embedding_cache[token_id]
# AFTER — the pre-computed table is an explicit function input
def get_embedding(table, token_id):
    """Look up ``token_id`` in ``table``; no hidden cache, no side effects."""
    row = table[token_id]
    return row

# Or: compute the embedding for the whole vocabulary in one batched call.
# NOTE(review): assumes expensive_compute is traceable by JAX (so vmap can batch it)
# and that vocab_size is a Python int — confirm.
embeddings = jax.vmap(expensive_compute)(jnp.arange(vocab_size))

# Subsequent lookups are plain indexing
def lookup(token_ids, table=None):
    """Gather rows of an embedding table for ``token_ids``.

    Args:
        token_ids: integer indices into the table.
        table: the embedding table to index. Defaults to the module-level
            ``embeddings`` for backward compatibility. Pass it explicitly to
            keep the function pure: a closed-over global is captured as a
            constant when the function is traced under jit, so later rebinds
            of the global would not be seen by the compiled function.

    Returns:
        ``table[token_ids]``.
    """
    if table is None:
        table = embeddings
    return table[token_ids]

케이스 4: layer 의 mutable attr

# BEFORE (PyTorch style): running statistics live on the object and are mutated in place.
class BatchNorm:
    """Stateful batch-norm sketch; the running statistics are attributes of self."""

    def __init__(self, dim):
        # Running statistics start at 0 (mean) and 1 (variance), sized by `dim`.
        self.running_mean = np.zeros(dim)
        self.running_var = np.ones(dim)

    def __call__(self, x, training=True):
        # NOTE(review): excerpt only — the normalization and return are elided below,
        # so as written this method returns None.
        if training:
            # Batch statistics reduced over axis 0 (presumably the batch axis).
            mean = x.mean(0); var = x.var(0)
            # Exponential moving average (decay 0.9) written back onto self: this
            # hidden state is what the AFTER version turns into explicit in/out values.
            self.running_mean = 0.9 * self.running_mean + 0.1 * mean  # ❌ mutation
            self.running_var = 0.9 * self.running_var + 0.1 * var  # ❌ mutation
        # ... (rest of the layer omitted in this excerpt)
# AFTER — running stats go in as arguments and come out as return values
def batch_norm(x, params, stats, training=True, momentum=0.9, eps=1e-5):
    """Pure batch normalization with explicit running-statistics state.

    Args:
        x: input batch; statistics are reduced over axis 0.
        params: dict with the learnable scale "gamma" and shift "beta".
        stats: dict with the running statistics under keys "mean" and "var".
        training: if True, normalize with the batch statistics and return
            updated running stats. If False, normalize with ``stats`` and
            return it unchanged. Drives a Python ``if``, so it must be static
            under jax.jit.
        momentum: EMA decay for the running stats,
            new = momentum * old + (1 - momentum) * batch. Defaults to 0.9.
        eps: added to the variance before the square root. Defaults to 1e-5.

    Returns:
        ``(output, new_stats)``.
    """
    if training:
        mean = x.mean(0)
        var = x.var(0)  # biased (ddof=0) batch variance
        new_stats = {
            "mean": momentum * stats["mean"] + (1 - momentum) * mean,
            "var": momentum * stats["var"] + (1 - momentum) * var,
        }
    else:
        mean, var = stats["mean"], stats["var"]
        new_stats = stats
    x_norm = (x - mean) / jnp.sqrt(var + eps)
    return params["gamma"] * x_norm + params["beta"], new_stats

📐 일관된 패턴: state-in / state-out

JAX 어디서나 보이는 모양 — 함수가 state 를 받아서 다음 state 를 내놓음. (state, x) → (new_state, output). PyTorch 의 self.x 가 함수의 input/output 으로 분해되는 거. 처음엔 verbose 하지만 — 모든 게 보이게 됨. 추적, 디버깅, jit, vmap 다 자연스러움.

리팩터링 마지막 단계: 같은 함수가 jit 안에서 도는지 확인.

# train_step already carries @jax.jit on its definition, so it is not wrapped a second time here.
# p is marked static: each distinct p value gets its own compiled version.
dropout = jax.jit(dropout, static_argnames=("p",))
# `training` selects a Python `if` branch inside batch_norm, so it must be static.
batch_norm = jax.jit(batch_norm, static_argnames=("training",))

compile 한 번, 호출 N 번 — JAX 식 효율의 정석.

Code

import numpy as np

# Impure baseline: every update() call mutates the instance's fields in place.
class RunningStats:
    """Streaming mean/variance via Welford's online algorithm (mutable NumPy baseline).

    Attributes:
        count: number of samples folded in so far.
        mean: running mean.
        variance: despite its name, the running sum of squared deviations (M2);
            get_stats() divides it to obtain the variance.
    """

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.variance = 0.0

    def update(self, batch):
        """Fold each element of ``batch`` into the running statistics (mutates self)."""
        for value in batch:
            self.count += 1
            shift = value - self.mean
            self.mean += shift / self.count
            self.variance += shift * (value - self.mean)

    def get_stats(self):
        """Return ``(mean, variance)``, with the divisor clamped to 1 below two samples."""
        divisor = self.count - 1 if self.count > 1 else 1
        return self.mean, self.variance / divisor

# Usage: side effects everywhere
stats = RunningStats()
for _ in range(10):
    # Draws from NumPy's global RNG, which is itself hidden mutable state.
    batch = np.random.randn(100)
    stats.update(batch)  # mutates `stats` in place
    mean, var = stats.get_stats()
    print(f"Mean: {mean:.4f}, Var: {var:.4f}")
import jax
import jax.numpy as jnp

# State is a plain data structure (NamedTuple or dict)
from typing import NamedTuple

class StatsState(NamedTuple):
    """Immutable Welford accumulator; as a NamedTuple it is a JAX pytree."""

    count: jnp.ndarray  # number of samples folded in so far
    mean: jnp.ndarray  # running mean
    m2: jnp.ndarray  # Sum of squared deviations from the running mean

def init_stats():
    """Build the starting accumulator: no samples seen, zero mean, zero M2."""
    zero_count = jnp.array(0, dtype=jnp.int32)
    return StatsState(zero_count, jnp.array(0.0), jnp.array(0.0))

def update_stats(state, x):
    """Fold one sample ``x`` into ``state`` (Welford) and return a new StatsState.

    Pure: the incoming state is never modified.
    """
    n = state.count + 1
    shift = x - state.mean
    new_mean = state.mean + shift / n
    new_m2 = state.m2 + shift * (x - new_mean)
    return StatsState(n, new_mean, new_m2)

def get_stats(state):
    """Return ``(mean, variance)`` from a Welford accumulator.

    The variance is ``m2 / (count - 1)`` when more than one sample has been
    seen, and 0.0 otherwise.
    """
    has_spread = state.count > 1
    # Double-where: choose a safe divisor first so the discarded branch never
    # evaluates 0/0. A single jnp.where still computes both branches, and that
    # NaN leaks into gradients (and trips jax_debug_nans in eager mode).
    divisor = jnp.where(has_spread, state.count - 1, 1)
    variance = jnp.where(has_spread, state.m2 / divisor, 0.0)
    return state.mean, variance

# Process a batch using jax.lax.scan (functional loop over the leading axis)
@jax.jit
def process_batch(state, batch):
    """Fold every element of ``batch`` (along axis 0) into ``state``.

    Returns the new state. An empty batch returns ``state`` unchanged.
    """
    def step(carry, x):
        # scan threads the state through; no per-element output is collected.
        return update_stats(carry, x), None

    state, _ = jax.lax.scan(step, state, batch)
    return state

# Usage: explicit state threading, no mutation
state = init_stats()
key = jax.random.PRNGKey(42)
for i in range(10):
    # Split off a fresh subkey per batch; reusing one key would repeat the same draws.
    key, subkey = jax.random.split(key)
    batch = jax.random.normal(subkey, (100,))
    # The returned state replaces the old binding; nothing is mutated in place.
    state = process_batch(state, batch)
    mean, var = get_stats(state)
    print(f"Step {i}: Mean: {mean:.4f}, Var: {var:.4f}")

External links

Exercise

lesson 3-2 의 list 에서 가장 심한 impure 함수. pure 형태로 refactor. jit 합성. 원본과 정확성 검증. 동료에게 경고할 한 마디 적기 — 그 메모가 refactor 자체보다 더 가치 있음.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요? 문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인 · 댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.