pytree 를 list of leaves 로 풀고, 다시 묶을 수 있어 — debugging / serialization / 임의 manipulation 에 필수.
import jax
params = {
"W1": jnp.zeros((3, 4)),
"b1": jnp.zeros(4),
"W2": jnp.zeros((4, 2)),
"b2": jnp.zeros(2),
}
# 1. flatten — leaves + structure
leaves, treedef = jax.tree.flatten(params)
# leaves: [W1, b1, W2, b2] (list of arrays)
# treedef: 트리 구조 정보 (PyTreeDef)
print(f"leaves: {len(leaves)}")
for i, l in enumerate(leaves):
print(f" [{i}] shape={l.shape}")
# 2. unflatten — leaves 와 structure 로 재구성
restored = jax.tree.unflatten(treedef, leaves)
# restored 는 원본 params 와 동일
# 3. tree.leaves / tree.structure (단축)
just_leaves = jax.tree.leaves(params)
just_struct = jax.tree.structure(params)
활용 1: 모든 parameter 를 한 vector 로
def params_to_flat_vector(params):
'''pytree → 1D vector'''
leaves = jax.tree.leaves(params)
return jnp.concatenate([l.ravel() for l in leaves])
def flat_vector_to_params(flat, template):
'''1D vector → pytree (template 의 구조 사용)'''
leaves = jax.tree.leaves(template)
sizes = [l.size for l in leaves]
shapes = [l.shape for l in leaves]
treedef = jax.tree.structure(template)
chunks = jnp.split(flat, jnp.cumsum(jnp.array(sizes))[:-1])
new_leaves = [c.reshape(s) for c, s in zip(chunks, shapes)]
return jax.tree.unflatten(treedef, new_leaves)
# 사용
flat = params_to_flat_vector(params)
print(flat.shape) # (총 parameter 개수,)
restored = flat_vector_to_params(flat, params)
2nd-order optimizer (LBFGS, Newton's method) 가 — pytree 를 flat vector 로 다뤄야 할 때 사용. jax.flatten_util.ravel_pytree 가 같은 일을 한 줄로 해 줌:
from jax.flatten_util import ravel_pytree
flat, unravel = ravel_pytree(params)
# flat: 1D vector
# unravel: 함수, flat → 원래 pytree
# 사용
restored = unravel(flat)
new_flat = flat + 0.01 # 어떤 transform
new_params = unravel(new_flat)
활용 2: 디버깅 — tree 구조 출력
def print_tree(tree, prefix=""):
if isinstance(tree, dict):
for k, v in tree.items():
if isinstance(v, dict):
print(f"{prefix}{k}/")
print_tree(v, prefix + " ")
else:
print(f"{prefix}{k}: shape={v.shape}, dtype={v.dtype}")
print_tree(params)
# W1: shape=(3, 4), dtype=float32
# b1: shape=(4,), dtype=float32
# ...
# 또는 한 줄 — leaves 와 path 함께
import jax.tree_util
def display(p, v):
return f"{'.'.join(str(x.key) for x in p)}: {v.shape}"
jax.tree_util.tree_map_with_path(display, params)
활용 3: parameter count
def count_params(params):
return sum(x.size for x in jax.tree.leaves(params))
n = count_params(transformer_params)
print(f"{n:,} parameters")
# 예: 124,440,000 (124M model)
# layer 별
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
print(f"{path}: {leaf.size:,}")
🔧 flatten/unflatten 의 정신
pytree 가 매력적인 이유 — 두 가지를 분리. 한 쪽에는 "구조" (어떤 모양으로 nested 되어 있나), 한 쪽에는 "값" (실제 array). 같은 구조에 다른 값을 끼워 넣을 수 있어 — checkpoint loading, parameter swapping, 모든 manipulation 이 깨끗.
주의 — treedef 자체는 hashable 이라 jit 의 static_argnames 에 넣을 수 있음. 그래서 trace 안에서 tree 구조가 자동으로 캐싱.