실제로 마주칠 흔한 impure 패턴 → 깨끗한 JAX 형태로 옮기는 연습이야.
케이스 1: 학습 step counter
# BEFORE
class Trainer:
def __init__(self):
self.step = 0
self.loss_history = []
def train_step(self, params, x, y):
self.step += 1
loss = compute_loss(params, x, y)
self.loss_history.append(loss)
new_params = params - 0.01 * jax.grad(compute_loss)(params, x, y)
return new_params
# 문제: step / loss_history 가 self mutation. jit 으로 못 감쌈.
# AFTER — state 를 명시적으로
@jax.jit
def train_step(state, params, x, y):
loss, grads = jax.value_and_grad(compute_loss)(params, x, y)
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
new_state = {
"step": state["step"] + 1,
"last_loss": loss,
}
return new_state, new_params
# 사용
state = {"step": 0, "last_loss": 0.0}
for x, y in batches:
state, params = train_step(state, params, x, y)
# loss_history 는 호출 측에서 모음 (Python list 는 jit 밖에서 OK)
loss_history.append(float(state["last_loss"]))
케이스 2: dropout — random state 의 명시화
# BEFORE
def dropout(x, p=0.5):
mask = np.random.rand(*x.shape) > p # ❌ global random
return x * mask / (1 - p)
# AFTER — key 를 인자로
def dropout(x, key, p=0.5):
mask = jax.random.bernoulli(key, 1 - p, x.shape)
return x * mask / (1 - p)
# 사용
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
y = dropout(x, subkey, p=0.5)
케이스 3: lookup cache
# BEFORE
embedding_cache = {}
def get_embedding(token_id):
if token_id not in embedding_cache:
embedding_cache[token_id] = expensive_compute(token_id)
return embedding_cache[token_id]
# AFTER — pre-computed table 을 함수 input 으로
def get_embedding(table, token_id):
return table[token_id]
# 또는 vocabulary 전체를 한 번에 처리
embeddings = jax.vmap(expensive_compute)(jnp.arange(vocab_size))
# 후속 lookup 은 그냥 indexing
def lookup(token_ids):
return embeddings[token_ids]
케이스 4: layer 의 mutable attr
# BEFORE (PyTorch 식)
class BatchNorm:
def __init__(self, dim):
self.running_mean = np.zeros(dim)
self.running_var = np.ones(dim)
def __call__(self, x, training=True):
if training:
mean = x.mean(0); var = x.var(0)
self.running_mean = 0.9 * self.running_mean + 0.1 * mean # ❌ mutation
self.running_var = 0.9 * self.running_var + 0.1 * var
# ...
# AFTER — running stats 를 in/out 으로
def batch_norm(x, params, stats, training=True):
'''params: gamma, beta. stats: running_mean, running_var.'''
if training:
mean = x.mean(0); var = x.var(0)
new_stats = {
"mean": 0.9 * stats["mean"] + 0.1 * mean,
"var": 0.9 * stats["var"] + 0.1 * var,
}
x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
else:
x_norm = (x - stats["mean"]) / jnp.sqrt(stats["var"] + 1e-5)
new_stats = stats
return params["gamma"] * x_norm + params["beta"], new_stats
📐 일관된 패턴: state-in / state-out
JAX 어디서나 보이는 모양 — 함수가 state 를 받아서 다음 state 를 내놓음. (state, x) → (new_state, output). PyTorch 의 self.x 가 함수의 input/output 으로 분해되는 거. 처음엔 verbose 하지만 — 모든 게 보이게 됨. 추적, 디버깅, jit, vmap 다 자연스러움.
리팩터링 마지막 단계: 같은 함수가 jit 안에서 도는지 확인.
train_step = jax.jit(train_step)
dropout = jax.jit(dropout, static_argnames=("p",))
batch_norm = jax.jit(batch_norm, static_argnames=("training",))
compile 한 번, 호출 N 번 — JAX 식 효율의 정석.