C.W.K.
Stream
Lesson 03 of 05 · published

Static Arguments 와 Control Flow

~9 min · jit, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jit 의 한계 — Python 의 if/for 같은 control flow 를 trace 시점의 abstract value 로는 결정할 수 없어. 두 가지 도구로 우회:

1. static_argnames / static_argnums — 인자를 "trace 의 일부가 아닌, cache key 의 일부" 로 표시.

from functools import partial

@partial(jax.jit, static_argnames=("mode",))
def f(x, mode):
    if mode == "train":      # OK — mode 는 concrete (static)
        return x * 2
    else:
        return x ** 2

f(jnp.array([1., 2.]), mode="train")   # compile 1
f(jnp.array([1., 2.]), mode="eval")    # compile 2
f(jnp.array([3., 4.]), mode="train")   # cache hit (compile 1)

주의 — static 인자의 값이 바뀔 때마다 새 compile. 너무 다양한 값으로 부르면 cache 폭발 → 메모리 누수처럼 보일 수 있음.

2. jax.lax.cond / jax.lax.switch — runtime 분기를 IR 안에서 표현.

import jax.lax as lax

@jax.jit
def f(x):
    return lax.cond(
        x > 0,                  # predicate (traced 가능)
        lambda x: x * 2,         # true branch
        lambda x: x ** 2,        # false branch
        x,                       # operand
    )

f(jnp.array(3.0))   # 6
f(jnp.array(-2.0))  # 4

둘 다 single jit 으로, runtime 분기. switch 는 N way:

def f(idx, x):
    return lax.switch(
        idx,
        [lambda x: x,
         lambda x: x ** 2,
         lambda x: jnp.sin(x)],
        x,
    )

Loop 는 jax.lax.scan:

# Python for-loop — trace 시점에 unroll. 길면 compile 오래 걸림.
@jax.jit
def f_unroll(x):
    for _ in range(100):
        x = jnp.sin(x)
    return x

# scan — IR 의 loop 로 compile. 길이 무관 한 번만 compile.
@jax.jit
def f_scan(x):
    def step(carry, _):
        return jnp.sin(carry), None
    final, _ = jax.lax.scan(step, x, jnp.zeros(100))
    return final

scan 의 의미: step(carry, x_t) → (new_carry, y_t). RNN, training loop, 누적 등 거의 모든 sequential 연산에 사용.

fori_loop / while_loop 도 있음:

# 정해진 횟수 — fori_loop
def body(i, x):
    return x + jnp.sin(x * i)
result = jax.lax.fori_loop(0, 100, body, init_val)

# 조건 종료 — while_loop
def cond(state): return state["i"] < 100 and state["err"] > 1e-6
def body(state): ...
final = jax.lax.while_loop(cond, body, initial_state)

⚠️ Python for vs lax.scan

짧은 loop (10 회 이하) 는 Python for 로 unroll 해도 OK — compile 시간 거의 같음. 그러나 100, 1000 step 학습 loop 를 Python for 로 jit 하면 — compile 이 분 단위로 걸리고 코드 사이즈가 폭발. 항상 scan 으로.

가이드라인:

  • 인자가 dispatch flag 면 → static_argnames
  • data-dependent 분기면 → jax.lax.cond / jnp.where
  • 고정 길이 loop, 짧음 (≤ 10) → Python for
  • 고정 길이 loop, 김 → jax.lax.fori_loop 또는 scan
  • 가변 길이 / 조건 종료 → jax.lax.while_loop
  • step 마다 output 수집 → jax.lax.scan

Code

import jax
import jax.numpy as jnp
from functools import partial

# Problem: this fails because 'use_bias' is traced
@jax.jit
def linear_bad(x, w, b, use_bias):
    result = x @ w
    if use_bias:  # ConcretizationTypeError: can't use traced bool in if
        result = result + b
    return result

# Solution 1: static_argnums
@partial(jax.jit, static_argnums=(3,))
def linear_v1(x, w, b, use_bias):
    result = x @ w
    if use_bias:  # Now 'use_bias' is concrete — this works!
        result = result + b
    return result

# Solution 2: static_argnames (clearer)
@partial(jax.jit, static_argnames=('use_bias',))
def linear_v2(x, w, b, use_bias):
    result = x @ w
    if use_bias:
        result = result + b
    return result

x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
b = jnp.ones(2)

result = linear_v2(x, w, b, use_bias=True)
print(result.shape)  # (4, 2)
import jax
import jax.numpy as jnp

# jax.lax.cond: if-else for traced values
@jax.jit
def safe_divide(x, y):
    return jax.lax.cond(
        y != 0,
        lambda: x / y,         # true branch
        lambda: jnp.zeros_like(x)  # false branch
    )

# jax.lax.switch: multi-way branch (like switch/case)
@jax.jit
def activation(x, choice):
    return jax.lax.switch(choice, [
        lambda x: x,                          # 0: identity
        lambda x: jnp.maximum(x, 0),          # 1: relu
        lambda x: jnp.tanh(x),                # 2: tanh
        lambda x: jax.nn.sigmoid(x),           # 3: sigmoid
    ], x)

# jax.lax.fori_loop: for loop with traced bounds
@jax.jit
def power(x, n):
    return jax.lax.fori_loop(
        0, n,                    # start, stop (can be traced)
        lambda i, acc: acc * x,  # body: (iteration, carry) -> new_carry
        jnp.ones_like(x)        # initial carry
    )

# jax.lax.scan: the most powerful loop primitive
@jax.jit
def cumulative_product(arr):
    def step(carry, x):
        new_carry = carry * x
        return new_carry, new_carry  # (next carry, output)
    _, products = jax.lax.scan(step, jnp.array(1.0), arr)
    return products

print(cumulative_product(jnp.array([2.0, 3.0, 4.0])))  # [2. 6. 24.]

External links

Exercise

f(x, mode) 함수 — mode 가 string ('train' or 'eval'). static_argnames 없이 jit — error 관찰. 그 다음 static_argnames=('mode',) 추가. mode='unknown' 시도하고 JAX 가 retrace 하는 거 본다. 총 compile 횟수 세기.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.