JAX 의 거의 모든 함수가 받는 자료 구조 — pytree. 이름은 거창하지만 — 그냥 "중첩된 Python container 의 트리, leaf 가 array".
import jax
import jax.numpy as jnp
# 단순 array 도 pytree
a = jnp.array([1, 2, 3])
# dict — pytree
params = {
"W": jnp.zeros((3, 4)),
"b": jnp.zeros((4,)),
}
# list — pytree
weights = [jnp.zeros((3, 4)), jnp.zeros((4, 2))]
# tuple — pytree
state = (jnp.zeros(()), jnp.zeros((10,)))
# nested — 다 pytree
deep = {
"encoder": [jnp.zeros((3, 8)), jnp.zeros((8, 4))],
"decoder": [jnp.zeros((4, 8)), jnp.zeros((8, 3))],
"stats": (jnp.array(0), jnp.array(0.0)),
}
JAX 가 pytree 를 다룰 수 있는 컨테이너:
list,tuple,dictcollections.OrderedDict,defaultdictNamedTuple@dataclass(등록 후)None(leaf 가 없는 빈 자리)
"leaf" 는 pytree 에서 더 이상 풀 수 없는 단위 — 보통 jnp.array 또는 Python scalar.
주요 함수 — jax.tree.map
params = {
"W": jnp.zeros((3, 4)),
"b": jnp.zeros((4,)),
}
# 모든 leaf 에 함수 적용
doubled = jax.tree.map(lambda x: x * 2, params)
# {"W": (3,4) of 0, "b": (4,) of 0} — 0 이라 다 0 이지만, 모양 보존
zeros = jax.tree.map(jnp.zeros_like, params)
# 같은 모양의 새 pytree, 모두 0
ones = jax.tree.map(jnp.ones_like, params)
여러 pytree 를 동시에
grads = {
"W": jnp.ones((3, 4)),
"b": jnp.ones((4,)),
}
# params 와 grads 를 zip 해서 update
new_params = jax.tree.map(
lambda p, g: p - 0.01 * g,
params,
grads,
)
# 두 pytree 의 구조가 같아야 함 (key, shape, ...)
이게 — 모든 JAX optimizer / 학습 코드 의 핵심 패턴. params 가 어떤 모양이든 (단순 array, dict, NamedTuple, deep nested) — 같은 코드.
🌳 pytree 의 의의
JAX 가 PyTorch 와 가장 다른 점 중 하나 — model parameters 를 dict / list / dataclass 로 자유롭게 표현. nn.Module 클래스 상속 안 해도 됨. 모든 transformation (jit, grad, vmap) 이 pytree 를 자동으로 walk. 모델이 더 이상 magic 한 객체가 아니라 — 그냥 nested 데이터.
실용 예 — Transformer block 의 params:
block_params = {
"attention": {
"W_q": jnp.zeros((d, d)),
"W_k": jnp.zeros((d, d)),
"W_v": jnp.zeros((d, d)),
"W_o": jnp.zeros((d, d)),
},
"mlp": {
"W1": jnp.zeros((d, 4*d)),
"W2": jnp.zeros((4*d, d)),
},
"ln1": {"gamma": jnp.ones(d), "beta": jnp.zeros(d)},
"ln2": {"gamma": jnp.ones(d), "beta": jnp.zeros(d)},
}
# zero gradient
zero_grads = jax.tree.map(jnp.zeros_like, block_params)
# parameter 개수 세기
n_params = sum(x.size for x in jax.tree.leaves(block_params))
print(f"params: {n_params:,}")
모든 게 dict 안의 dict 라 — 자유롭게 inspecting / manipulating. 이게 functional 모델 표현의 자유도.