Lesson 03 of 05 · published

Scan Pattern and Structured Control Flow

~9 min · advanced, jax, tutorial


jax.lax.scan is the standard tool for sequential computation in JAX. Use it for RNNs, training loops, cumulative computations: any pattern where one step's result is consumed by the next step.

Basic usage

import jax
import jax.numpy as jnp

def step(carry, x):
    '''carry: previous state, x: current input'''
    new_carry = carry + x
    output = new_carry * 2
    return new_carry, output

# scan: (carry, [x0, x1, ...]) → (final_carry, [y0, y1, ...])
xs = jnp.array([1., 2., 3., 4., 5.])
init_carry = 0.0

final_carry, ys = jax.lax.scan(step, init_carry, xs)
print(f"final_carry: {final_carry}")   # 15
print(f"ys: {ys}")                      # [2, 6, 12, 20, 30]

This has the same meaning as the following Python for-loop:

carry = init_carry
ys = []
for x in xs:
    carry, y = step(carry, x)
    ys.append(y)
ys = jnp.stack(ys)

The difference: scan compiles to a single XLA loop and does not unroll. The IR size is the same whether the length is 100, 1000, or 10000.
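One way to check this claim is to count the equations in the traced program. The sketch below is illustrative only: `count_eqns` is a hypothetical helper, not part of JAX, and the exact counts depend on the JAX version. What matters is that the scan count stays flat while the unrolled loop's count grows with the length.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry * 2

def with_scan(xs):
    return jax.lax.scan(step, 0.0, xs)

def with_python_loop(xs):
    # Iterating over a traced array unrolls the loop at trace time.
    carry, ys = 0.0, []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def count_eqns(fn, n):
    # Hypothetical helper: number of top-level equations in the jaxpr.
    return len(jax.make_jaxpr(fn)(jnp.ones(n)).jaxpr.eqns)

scan_small, scan_large = count_eqns(with_scan, 10), count_eqns(with_scan, 1000)
loop_small, loop_large = count_eqns(with_python_loop, 10), count_eqns(with_python_loop, 100)

print("scan:", scan_small, scan_large)         # same count for both lengths
print("python loop:", loop_small, loop_large)  # count grows with length
```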

RNN implementation

def rnn_cell(params, h, x):
    '''Single step'''
    h_new = jnp.tanh(params["W_h"] @ h + params["W_x"] @ x + params["b"])
    return h_new

def rnn_forward(params, h0, xs):
    '''xs: (T, input_dim), the input sequence'''
    def step(h, x):
        h_new = rnn_cell(params, h, x)
        return h_new, h_new   # carry: hidden state; output: the hidden state as well

    final_h, hs = jax.lax.scan(step, h0, xs)
    return hs   # (T, hidden_dim)
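A minimal usage sketch for the RNN above. The sizes, the random initialization, and the 0.1 weight scale are assumptions chosen for the demo, not part of the lesson. It runs the scan version under jit and compares it against the same recurrence written as a plain Python loop.

```python
import jax
import jax.numpy as jnp

def rnn_cell(params, h, x):
    return jnp.tanh(params["W_h"] @ h + params["W_x"] @ x + params["b"])

def rnn_forward(params, h0, xs):
    def step(h, x):
        h_new = rnn_cell(params, h, x)
        return h_new, h_new
    _, hs = jax.lax.scan(step, h0, xs)
    return hs

# Assumed sizes and initialization, chosen only for this demo.
T, input_dim, hidden_dim = 6, 3, 4
k_h, k_x, k_in = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_h": 0.1 * jax.random.normal(k_h, (hidden_dim, hidden_dim)),
    "W_x": 0.1 * jax.random.normal(k_x, (hidden_dim, input_dim)),
    "b": jnp.zeros(hidden_dim),
}
xs = jax.random.normal(k_in, (T, input_dim))
h0 = jnp.zeros(hidden_dim)

hs = jax.jit(rnn_forward)(params, h0, xs)
print(hs.shape)  # (6, 4)

# Reference: the same recurrence as a plain Python loop.
h, ref = h0, []
for t in range(T):
    h = rnn_cell(params, h, xs[t])
    ref.append(h)
ref = jnp.stack(ref)
print(jnp.allclose(hs, ref, atol=1e-5))
```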

scan with per-step inputs + external inputs

# A training loop can also be written as a scan.
# train_step and get_all_batches are assumed to be defined elsewhere.
def train_step_for_scan(state, batch):
    # batch is the input for each step
    new_state, loss, _ = train_step(state, batch)
    return new_state, loss

# One epoch: the batches for 100 steps
batches = get_all_batches()  # shape: (100, B, ...)
final_state, losses = jax.lax.scan(train_step_for_scan, state, batches)

If the training step is exactly the same on every iteration, use scan instead of a Python for-loop: host overhead drops to nearly zero.
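The snippet above leaves `train_step` and `get_all_batches` undefined. The sketch below fills them with hypothetical stand-ins: plain SGD on a 1-D linear model, synthetic data from y = 2x + 1, and an assumed learning rate of 0.1. The batches are stacked along a leading axis of length 100, which is the axis scan iterates over.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for the demo

def loss_fn(params, batch):
    x, y = batch
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(state, batch):
    # Hypothetical stand-in: one SGD update, returning (state, loss, aux).
    loss, grads = jax.value_and_grad(loss_fn)(state, batch)
    new_state = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, state, grads
    )
    return new_state, loss, None

def train_step_for_scan(state, batch):
    new_state, loss, _ = train_step(state, batch)
    return new_state, loss

# Synthetic data: 100 batches of 32 points drawn from y = 2x + 1.
x_all = jax.random.normal(jax.random.PRNGKey(0), (100, 32))
batches = (x_all, 2.0 * x_all + 1.0)

state = {"w": jnp.zeros(()), "b": jnp.zeros(())}
run_epoch = jax.jit(lambda s, b: jax.lax.scan(train_step_for_scan, s, b))
final_state, losses = run_epoch(state, batches)

print(losses.shape)  # (100,)
print(final_state)   # w and b should end up near 2 and 1
```

The whole epoch is one compiled call, so Python is not involved between steps.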

Accumulations like cumsum

# jnp already has cumsum; this is the generalized accumulation pattern:
def custom_cumsum(xs):
    def step(acc, x):
        new_acc = acc + x
        return new_acc, new_acc

    _, cumulative = jax.lax.scan(step, 0.0, xs)
    return cumulative

# exponential moving average
def ema(xs, alpha=0.1):
    def step(prev, x):
        new = alpha * x + (1 - alpha) * prev
        return new, new

    _, smooth = jax.lax.scan(step, xs[0], xs[1:])
    return jnp.concatenate([xs[:1], smooth])
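The carry does not have to be a single array. As a further sketch of the same accumulation pattern, the hypothetical `running_mean` below carries a (count, total) tuple and emits the mean so far at each step.

```python
import jax
import jax.numpy as jnp

def running_mean(xs):
    def step(carry, x):
        count, total = carry
        count = count + 1
        total = total + x
        return (count, total), total / count

    # The carry is a tuple; its structure and dtypes must stay fixed across steps.
    init = (jnp.zeros((), dtype=jnp.int32), jnp.zeros((), dtype=xs.dtype))
    _, means = jax.lax.scan(step, init, xs)
    return means

means = running_mean(jnp.array([1.0, 2.0, 3.0, 4.0]))
print(means)  # running means: 1, 1.5, 2, 2.5
```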

fori_loop / while_loop: when per-step outputs are not accumulated

# fori_loop: a fixed N iterations, returns only the final state
def newton_iteration(f, f_prime, x0, n_iter=20):
    def body(i, x):
        return x - f(x) / f_prime(x)
    return jax.lax.fori_loop(0, n_iter, body, x0)

# while_loop: terminates on a condition
def find_root(f, f_prime, x0, tol=1e-6, max_iter=100):
    def cond(state):
        x, i = state
        return (jnp.abs(f(x)) > tol) & (i < max_iter)
    def body(state):
        x, i = state
        return (x - f(x) / f_prime(x), i + 1)
    final_x, _ = jax.lax.while_loop(cond, body, (x0, 0))
    return final_x
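A short usage sketch for the two functions above, on an assumed test problem (the positive root of x^2 - 2, starting from 1.0). The starting point is passed as a float32 scalar so that the loop state keeps one dtype from start to finish.

```python
import jax
import jax.numpy as jnp

def newton_iteration(f, f_prime, x0, n_iter=20):
    def body(i, x):
        return x - f(x) / f_prime(x)
    return jax.lax.fori_loop(0, n_iter, body, x0)

def find_root(f, f_prime, x0, tol=1e-6, max_iter=100):
    def cond(state):
        x, i = state
        return (jnp.abs(f(x)) > tol) & (i < max_iter)
    def body(state):
        x, i = state
        return (x - f(x) / f_prime(x), i + 1)
    final_x, _ = jax.lax.while_loop(cond, body, (x0, 0))
    return final_x

# Assumed test problem: the positive root of x^2 - 2.
f = lambda x: x * x - 2.0
f_prime = lambda x: 2.0 * x

root_fixed = newton_iteration(f, f_prime, jnp.float32(1.0))  # always 20 steps
root_adaptive = find_root(f, f_prime, jnp.float32(1.0))      # stops at tol or max_iter
print(root_fixed, root_adaptive)  # both close to sqrt(2)
```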

📐 Choosing between scan / fori_loop / while_loop

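(1) scan: also collects each step's output. RNNs, training loops.
(2) fori_loop: a fixed N iterations, no output accumulation. Plain iteration.
(3) while_loop: terminates on a condition, dynamic length. Iterative solvers, BFS.

Under jit, all three are traced and compiled once, regardless of the iteration count. For large N this usually beats a Python for-loop by a wide margin: a Python loop inside jit is unrolled at trace time, and outside jit it pays dispatch overhead on every step.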

One caveat: the length of a scan must be fixed at trace time. For dynamic lengths, use while_loop, or a padded scan with masking.
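A minimal sketch of the padded-scan-with-masking option, assuming the sequence is padded to a fixed maximum length and the true length arrives as a runtime value. `masked_running_sum` is a hypothetical name. The scan always runs over the full padded length; steps past `valid_len` leave the carry untouched.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_running_sum(xs_padded, valid_len):
    def step(carry, inputs):
        i, x = inputs
        # Steps at or beyond valid_len leave the carry unchanged.
        new_carry = jnp.where(i < valid_len, carry + x, carry)
        return new_carry, new_carry

    idx = jnp.arange(xs_padded.shape[0])
    init = jnp.zeros((), dtype=xs_padded.dtype)
    return jax.lax.scan(step, init, (idx, xs_padded))

# Padding values are deliberately non-zero to show that they are ignored.
xs_padded = jnp.array([1.0, 2.0, 3.0, 99.0, 99.0])
total, partial = masked_running_sum(xs_padded, 3)
print(total)    # 6.0
print(partial)  # running sums 1, 3, 6, then held at 6
```

Because `valid_len` is a traced argument, calling the function with a different length reuses the same compiled program, at the cost of always executing the padded number of steps.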

Code

import jax
import jax.numpy as jnp

# Scan replaces sequential loops with a compiled operation
# Pattern: carry, output = scan(fn, init_carry, inputs)

# Example: running sum (like cumsum but showing the pattern)
def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (new_carry, output)

    final_carry, outputs = jax.lax.scan(step, init=0.0, xs=xs)
    return outputs

xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(running_sum(xs))  # [1. 3. 6. 10. 15.]

# Example: apply N identical layers (weight sharing)
def apply_shared_layers(params, x, num_layers):
    def layer_fn(x, _):
        x = jax.nn.relu(x @ params['w'] + params['b'])
        return x, None  # carry = next input, no per-step output

    final_x, _ = jax.lax.scan(layer_fn, x, xs=None, length=num_layers)
    return final_x

# while_loop for dynamic stopping conditions
def newton_sqrt(x, tol=1e-6, max_iter=100):
    """Newton's method for square root: the iteration count is not known in advance."""
    def cond_fn(state):
        guess, count = state
        # Stop on convergence, or after max_iter steps as a safety cap
        # (an absolute tolerance may be unreachable in float32 for large x).
        return (jnp.abs(guess * guess - x) > tol) & (count < max_iter)

    def body_fn(state):
        guess, count = state
        guess = (guess + x / guess) / 2.0
        return guess, count + 1

    init_state = (x / 2.0, 0)
    final_guess, num_iters = jax.lax.while_loop(cond_fn, body_fn, init_state)
    return final_guess, num_iters

# Works under JIT!
sqrt_5, iters = jax.jit(newton_sqrt)(5.0)
print(f"sqrt(5) ≈ {sqrt_5:.6f} in {iters} iterations")

# Conditional execution under JIT
def safe_divide(x, y):
    return jax.lax.cond(
        y != 0,
        lambda: x / y,           # true branch
        lambda: jnp.float32(0),  # false branch
    )

# Multi-way switch (like a compiled if/elif/else)
def activation(x, choice):
    return jax.lax.switch(
        choice,
        [
            lambda x: jax.nn.relu(x),     # choice=0
            lambda x: jax.nn.gelu(x),     # choice=1
            lambda x: jnp.tanh(x),        # choice=2
            lambda x: jax.nn.swish(x),    # choice=3
        ],
        x,
    )
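A small usage sketch for cond and switch under jit. It is a variant of the functions above rather than a copy: operands are passed to the branches explicitly, and the switch has only two branches. The point it illustrates is that the predicate and the index can be runtime values, while every branch must return the same shape and dtype.

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_divide(x, y):
    # Both branches take the same operands and return x's shape and dtype.
    return jax.lax.cond(
        y != 0,
        lambda a, b: a / b,
        lambda a, b: jnp.zeros_like(a),
        x, y,
    )

@jax.jit
def activation(x, choice):
    branches = [jax.nn.relu, jnp.tanh]  # choice=0, choice=1
    return jax.lax.switch(choice, branches, x)

one, zero = jnp.float32(1.0), jnp.float32(0.0)
print(safe_divide(one, jnp.float32(4.0)))  # 0.25
print(safe_divide(one, zero))              # 0.0

x = jnp.array([-1.0, 0.5])
print(activation(x, 0))  # relu: the negative entry becomes 0
print(activation(x, 1))  # tanh
```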


Exercise

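Implement the RNN forward pass three ways: (1) a Python for-loop, (2) jax.lax.scan, (3) jax.lax.fori_loop. Measure each at sequence length 1024, and work out when each one is the right tool. The lesson here is pattern recognition more than performance.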
