C.W.K.
Stream
Lesson 03 of 05 · published

Pytree Utility: flatten, unflatten, debugging

~10 min · pytrees, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

pytree 를 list of leaves 로 풀고, 다시 묶을 수 있어 — debugging / serialization / 임의 manipulation 에 필수.

import jax
import jax.numpy as jnp  # required: the arrays below are built with jnp.zeros

params = {
    "W1": jnp.zeros((3, 4)),
    "b1": jnp.zeros(4),
    "W2": jnp.zeros((4, 2)),
    "b2": jnp.zeros(2),
}

# 1. flatten — split the pytree into leaves + structure
leaves, treedef = jax.tree.flatten(params)
# leaves: [W1, W2, b1, b2]  (list of arrays; dict keys are flattened in
#         sorted order, not insertion order)
# treedef: description of the tree structure (PyTreeDef)

print(f"leaves: {len(leaves)}")
for i, leaf in enumerate(leaves):
    print(f"  [{i}] shape={leaf.shape}")

# 2. unflatten — rebuild the pytree from structure + leaves
restored = jax.tree.unflatten(treedef, leaves)
# restored has the same structure and values as the original params

# 3. tree.leaves / tree.structure (shortcuts when only one half is needed)
just_leaves = jax.tree.leaves(params)
just_struct = jax.tree.structure(params)

활용 1: 모든 parameter 를 한 vector 로

def params_to_flat_vector(params):
    """Join every leaf of a pytree into a single 1-D vector.

    Leaves are visited in ``jax.tree.leaves`` order; each one is raveled
    and the pieces are concatenated end to end.
    """
    pieces = []
    for leaf in jax.tree.leaves(params):
        pieces.append(leaf.ravel())
    return jnp.concatenate(pieces)

def flat_vector_to_params(flat, template):
    """Rebuild a pytree from a flat vector, using ``template`` for the layout.

    Args:
        flat: array holding every parameter value back to back, in
            ``jax.tree.leaves(template)`` order.
        template: pytree whose structure and leaf shapes define the result.
            Only its structure and leaf shapes/sizes are read, never its
            values.

    Returns:
        A pytree with the structure of ``template`` whose leaves are
        consecutive slices of ``flat`` reshaped to the template's shapes.

    Raises:
        ValueError: if ``flat`` does not contain exactly as many elements
            as all leaves of ``template`` combined.
    """
    leaves, treedef = jax.tree.flatten(template)
    total = sum(leaf.size for leaf in leaves)
    if flat.size != total:
        raise ValueError(
            f"flat has {flat.size} elements but template needs {total}"
        )
    flat = flat.reshape(-1)

    # Offsets are plain Python ints derived from static shapes, so the
    # slicing also works when `flat` is a tracer inside jax.jit (array-valued
    # split indices would have to be concrete and fail there).
    new_leaves = []
    offset = 0
    for leaf in leaves:
        chunk = flat[offset:offset + leaf.size]
        new_leaves.append(chunk.reshape(leaf.shape))
        offset += leaf.size
    return jax.tree.unflatten(treedef, new_leaves)

# Usage: round-trip params through a single flat vector
flat = params_to_flat_vector(params)
print(flat.shape)   # (total number of parameters,)
restored = flat_vector_to_params(flat, params)

2nd-order / quasi-Newton optimizer (Newton's method, L-BFGS) 처럼 pytree 를 flat vector 하나로 다뤄야 할 때 사용해. jax.flatten_util.ravel_pytree 가 같은 일을 한 줄로 해 줌:

from jax.flatten_util import ravel_pytree

flat, unravel = ravel_pytree(params)
# flat: 1D vector
# unravel: function mapping a flat vector back to the original pytree layout

# Usage
restored = unravel(flat)
new_flat = flat + 0.01   # some transform applied in flat space
new_params = unravel(new_flat)

활용 2: 디버깅 — tree 구조 출력

def print_tree(tree, prefix=""):
    """Print a nested dict of arrays, one line per entry, indented by depth.

    A nested dict is shown as ``key/`` followed by its own entries with two
    extra spaces of indentation; any other value is shown with its ``shape``
    and ``dtype``. If ``tree`` itself is not a dict, nothing is printed.
    """
    if not isinstance(tree, dict):
        return
    for name, value in tree.items():
        if not isinstance(value, dict):
            print(f"{prefix}{name}: shape={value.shape}, dtype={value.dtype}")
            continue
        # Sub-tree: print a header line, then recurse one level deeper.
        print(f"{prefix}{name}/")
        print_tree(value, prefix + "  ")

# Expected output (entries appear in the dict's own iteration order,
# because print_tree walks tree.items()):
print_tree(params)
# W1: shape=(3, 4), dtype=float32
# b1: shape=(4,), dtype=float32
# ...

# Or in one line — leaves together with their paths
import jax.tree_util


def _path_entry_label(entry):
    """Return the printable label stored on a single pytree path entry."""
    # Dict-style entries store `.key`, sequence entries store `.idx`, and
    # attribute entries store `.name`; fall back to str() for anything else.
    for attr in ("key", "idx", "name"):
        if hasattr(entry, attr):
            return str(getattr(entry, attr))
    return str(entry)


def display(p, v):
    """Format one leaf as ``"dotted.path: shape"``.

    Args:
        p: key path of the leaf (sequence of path entries), as passed by
            ``jax.tree_util.tree_map_with_path``.
        v: the leaf; only its ``shape`` attribute is read.

    Returns:
        The path entries joined with ``.``, then ``": "`` and the shape.
    """
    return f"{'.'.join(_path_entry_label(x) for x in p)}: {v.shape}"

# Calls display(path, leaf) for every leaf of params. The result is not bound
# to a name here, so it is only visible when run in a REPL / notebook cell.
jax.tree_util.tree_map_with_path(display, params)

활용 3: parameter count

def count_params(params):
    """Return the total element count summed over every leaf of ``params``."""
    total = 0
    for leaf in jax.tree.leaves(params):
        total += leaf.size
    return total

# NOTE(review): transformer_params is not defined in this snippet —
# presumably the parameter pytree of some model; confirm before running.
n = count_params(transformer_params)
print(f"{n:,} parameters")
# e.g. 124,440,000 (124M model)

# Per layer: one line per leaf with its key path and element count
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    print(f"{path}: {leaf.size:,}")

🔧 flatten/unflatten 의 정신

pytree 가 매력적인 이유 — 두 가지를 분리. 한 쪽에는 "구조" (어떤 모양으로 nested 되어 있나), 한 쪽에는 "값" (실제 array). 같은 구조에 다른 값을 끼워 넣을 수 있어 — checkpoint loading, parameter swapping, 모든 manipulation 이 깨끗.

주의 — treedef 자체는 hashable 이라 jit 의 static_argnames 로 넘길 수 있음. 또 jit 은 입력 pytree 의 구조 (treedef) 를 컴파일 캐시 키의 일부로 사용해 — 같은 구조로 다시 호출하면 컴파일 결과를 재사용하고, 구조가 바뀌면 다시 trace 함.

Code

import jax
import jax.numpy as jnp

# Two-layer example: each layer is a dict with a weight 'w' and a bias 'b'.
params = {
    'layer1': {'w': jnp.array([[1.0, 2.0], [3.0, 4.0]]), 'b': jnp.array([0.1, 0.2])},
    'layer2': {'w': jnp.array([[5.0], [6.0]]), 'b': jnp.array([0.3])},
}

# Flatten: get leaves and structure separately
leaves, treedef = jax.tree.flatten(params)
print(f"Number of leaves: {len(leaves)}")  # 4
print(f"Tree structure: {treedef}")

# You can reconstruct the original tree
params_rebuilt = treedef.unflatten(leaves)
# params_rebuilt == params (same structure and values)

# Useful: concatenate all params into one vector
all_params = jnp.concatenate([x.ravel() for x in leaves])
print(f"Total params as one vector: {all_params.shape}")  # (9,) = 2 + 4 + 1 + 2 elements
# Get paths alongside leaves: each entry is a (key path, leaf) pair.
# Note that `treedef` is rebound by this call.
path_leaves, treedef = jax.tree.flatten_with_path(params)

for path, leaf in path_leaves:
    path_str = jax.tree_util.keystr(path)
    print(f"{path_str}: shape={leaf.shape}, dtype={leaf.dtype}")

# Output (dict keys are visited in sorted order, so 'b' comes before 'w'):
# ['layer1']['b']: shape=(2,), dtype=float32
# ['layer1']['w']: shape=(2, 2), dtype=float32
# ['layer2']['b']: shape=(1,), dtype=float32
# ['layer2']['w']: shape=(2, 1), dtype=float32
# Find NaN parameters
def check_for_nans(params):
    path_leaves, _ = jax.tree.flatten_with_path(params)
    for path, leaf in path_leaves:
        if jnp.any(jnp.isnan(leaf)):
            print(f"NaN found at {jax.tree_util.keystr(path)}!")

# Check parameter statistics
def param_summary(params):
    path_leaves, _ = jax.tree.flatten_with_path(params)
    for path, leaf in path_leaves:
        name = jax.tree_util.keystr(path)
        print(f"{name}: mean={leaf.mean():.4f}, std={leaf.std():.4f}, "
              f"shape={leaf.shape}")

param_summary(params)
# Output for the two-layer params used in this snippet:
# ['layer1']['b']: mean=0.1500, std=0.0500, shape=(2,)
# ['layer1']['w']: mean=2.5000, std=1.1180, shape=(2, 2)
# ...
# PyTorch equivalent: model.state_dict()
# state_dict = model.state_dict()
# for name, param in state_dict.items():
#     print(f"{name}: {param.shape}")
# Output: "layer1.weight: torch.Size([4, 3])"

# JAX equivalent using pytrees:
path_leaves, _ = jax.tree.flatten_with_path(params)
for path, leaf in path_leaves:
    print(f"{jax.tree_util.keystr(path)}: {leaf.shape}")
# First line of output for these params: "['layer1']['b']: (2,)"
# (the weight follows as "['layer1']['w']: (2, 2)")

# Both give named access to all parameters — different syntax, same idea

External links

Exercise

복잡한 nested params tree 를 하나 만들어 봐. (1) flatten 해서 leaf 개수를 세고, (2) jax.tree.reduce 로 총 parameter 개수 (각 leaf 의 size 합) 를 구하고, (3) unflatten 한 뒤 원본과 구조·값이 같은지 검증해. 'model params 개수' 는 이런 식으로 세 — Optax 가 내부 state 를 추적하는 것도 같은 패턴.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.