jax.lax.scan — JAX 에서 sequential 연산의 정석. RNN, training loop, 누적 계산 등 — 모든 "한 단계 결과를 다음 단계가 사용" 하는 패턴에 사용.
기본 사용
import jax
import jax.numpy as jnp
def step(carry, x):
'''carry: 이전 상태, x: 현재 입력'''
new_carry = carry + x
output = new_carry * 2
return new_carry, output
# scan: (carry, [x0, x1, ...]) → (final_carry, [y0, y1, ...])
xs = jnp.array([1., 2., 3., 4., 5.])
init_carry = 0.0
final_carry, ys = jax.lax.scan(step, init_carry, xs)
print(f"final_carry: {final_carry}") # 15
print(f"ys: {ys}") # [2, 6, 12, 20, 30]
Python for-loop 와 같은 의미:
carry = init_carry
ys = []
for x in xs:
carry, y = step(carry, x)
ys.append(y)
ys = jnp.stack(ys)
다른 점 — scan 은 — single XLA loop 로 컴파일. unrolling 안 함. 길이 100, 1000, 10000 다 같은 IR 사이즈.
RNN 구현
def rnn_cell(params, h, x):
'''단일 step'''
h_new = jnp.tanh(params["W_h"] @ h + params["W_x"] @ x + params["b"])
return h_new
def rnn_forward(params, h0, xs):
'''xs: (T, input_dim) — sequence'''
def step(h, x):
h_new = rnn_cell(params, h, x)
return h_new, h_new # carry: hidden, output: hidden 도
final_h, hs = jax.lax.scan(step, h0, xs)
return hs # (T, hidden_dim)
scan with input from each step + 외부 input
# 학습 loop 도 scan 으로
def train_step_for_scan(state, batch):
# batch 는 매 step 의 input
new_state, loss, _ = train_step(state, batch)
return new_state, loss
# 한 epoch — 100 step 의 batch 들
batches = get_all_batches() # shape: (100, B, ...)
final_state, losses = jax.lax.scan(train_step_for_scan, state, batches)
학습 step 이 매번 전혀 안 변하면 — Python for 보다 scan 으로 — host overhead 거의 0.
cumsum 같은 누적
# jnp 에 cumsum 있지만, 일반화된 누적 연산:
def custom_cumsum(xs):
def step(acc, x):
new_acc = acc + x
return new_acc, new_acc
_, cumulative = jax.lax.scan(step, 0.0, xs)
return cumulative
# exponential moving average
def ema(xs, alpha=0.1):
def step(prev, x):
new = alpha * x + (1 - alpha) * prev
return new, new
_, smooth = jax.lax.scan(step, xs[0], xs[1:])
return jnp.concatenate([xs[:1], smooth])
fori_loop / while_loop — 출력 누적 안 할 때
# fori_loop — 정해진 N 번 반복, 마지막 state 만
def newton_iteration(f, f_prime, x0, n_iter=20):
def body(i, x):
return x - f(x) / f_prime(x)
return jax.lax.fori_loop(0, n_iter, body, x0)
# while_loop — 조건 종료
def find_root(f, f_prime, x0, tol=1e-6, max_iter=100):
def cond(state):
x, i = state
return (jnp.abs(f(x)) > tol) & (i < max_iter)
def body(state):
x, i = state
return (x - f(x) / f_prime(x), i + 1)
final_x, _ = jax.lax.while_loop(cond, body, (x0, 0))
return final_x
📐 scan / fori_loop / while_loop 선택
(1) scan — 매 step 의 output 도 모음. RNN, training loop. (2) fori_loop — 정해진 N 번, output 누적 안 함. iteration. (3) while_loop — 조건 종료, dynamic 길이. iterative solver, BFS. 셋 다 jit 안에서 컴파일 단계 1 번. Python for-loop 보다 — 큰 N 에서 절대적으로 빠름.
한 가지 — scan 의 length 는 trace 시점에 정해져야 함. dynamic 길이는 — while_loop 또는 padded scan + masking.