C.W.K.
Stream
Lesson 01 of 05 · published

Pytree 가 뭐야?

~8 min · pytrees, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 의 거의 모든 함수가 받는 자료 구조 — pytree. 이름은 거창하지만 — 그냥 "중첩된 Python container 의 트리, leaf 가 array".

import jax
import jax.numpy as jnp

# 단순 array 도 pytree
a = jnp.array([1, 2, 3])

# dict — pytree
params = {
    "W": jnp.zeros((3, 4)),
    "b": jnp.zeros((4,)),
}

# list — pytree
weights = [jnp.zeros((3, 4)), jnp.zeros((4, 2))]

# tuple — pytree
state = (jnp.zeros(()), jnp.zeros((10,)))

# nested — 다 pytree
deep = {
    "encoder": [jnp.zeros((3, 8)), jnp.zeros((8, 4))],
    "decoder": [jnp.zeros((4, 8)), jnp.zeros((8, 3))],
    "stats": (jnp.array(0), jnp.array(0.0)),
}

JAX 가 pytree 를 다룰 수 있는 컨테이너:

  • list, tuple, dict
  • collections.OrderedDict, defaultdict
  • NamedTuple
  • @dataclass (등록 후)
  • None (leaf 가 없는 빈 자리)

"leaf" 는 pytree 에서 더 이상 풀 수 없는 단위 — 보통 jnp.array 또는 Python scalar.

주요 함수 — jax.tree.map

params = {
    "W": jnp.zeros((3, 4)),
    "b": jnp.zeros((4,)),
}

# 모든 leaf 에 함수 적용
doubled = jax.tree.map(lambda x: x * 2, params)
# {"W": (3,4) of 0, "b": (4,) of 0}  — 0 이라 다 0 이지만, 모양 보존

zeros = jax.tree.map(jnp.zeros_like, params)
# 같은 모양의 새 pytree, 모두 0

ones = jax.tree.map(jnp.ones_like, params)

여러 pytree 를 동시에

grads = {
    "W": jnp.ones((3, 4)),
    "b": jnp.ones((4,)),
}

# params 와 grads 를 zip 해서 update
new_params = jax.tree.map(
    lambda p, g: p - 0.01 * g,
    params,
    grads,
)
# 두 pytree 의 구조가 같아야 함 (key, shape, ...)

이게 — 모든 JAX optimizer / 학습 코드 의 핵심 패턴. params 가 어떤 모양이든 (단순 array, dict, NamedTuple, deep nested) — 같은 코드.

🌳 pytree 의 의의

JAX 가 PyTorch 와 가장 다른 점 중 하나 — model parameters 를 dict / list / dataclass 로 자유롭게 표현. nn.Module 클래스 상속 안 해도 됨. 모든 transformation (jit, grad, vmap) 이 pytree 를 자동으로 walk. 모델이 더 이상 magic 한 객체가 아니라 — 그냥 nested 데이터.

실용 예 — Transformer block 의 params:

block_params = {
    "attention": {
        "W_q": jnp.zeros((d, d)),
        "W_k": jnp.zeros((d, d)),
        "W_v": jnp.zeros((d, d)),
        "W_o": jnp.zeros((d, d)),
    },
    "mlp": {
        "W1": jnp.zeros((d, 4*d)),
        "W2": jnp.zeros((4*d, d)),
    },
    "ln1": {"gamma": jnp.ones(d), "beta": jnp.zeros(d)},
    "ln2": {"gamma": jnp.ones(d), "beta": jnp.zeros(d)},
}

# zero gradient
zero_grads = jax.tree.map(jnp.zeros_like, block_params)

# parameter 개수 세기
n_params = sum(x.size for x in jax.tree.leaves(block_params))
print(f"params: {n_params:,}")

모든 게 dict 안의 dict 라 — 자유롭게 inspecting / manipulating. 이게 functional 모델 표현의 자유도.

Code

import jax
import jax.numpy as jnp

# This is a pytree: nested dicts and lists containing arrays
model_params = {
    'encoder': {
        'weights': jnp.ones((784, 256)),
        'bias': jnp.zeros(256),
    },
    'decoder': {
        'weights': jnp.ones((256, 784)),
        'bias': jnp.zeros(784),
    },
}

# This is also a pytree: a list of tuples
layers = [
    (jnp.ones((3, 4)), jnp.zeros(4)),
    (jnp.ones((4, 2)), jnp.zeros(2)),
]

# Even a single array is a (trivial) pytree
single = jnp.array([1.0, 2.0, 3.0])

# You can see the leaves of any pytree
leaves = jax.tree.leaves(model_params)
print(f"Number of leaves: {len(leaves)}")  # 4 arrays
print(f"Shapes: {[l.shape for l in leaves]}")
# [(784, 256), (256,), (256, 784), (784,)]

External links

Exercise

jnp array 의 nested dict + list 로 pytree 수동 정의. jax.tree.map(lambda x: x*2, tree) 로 통과. output 구조 검사 — 같은 모양, leaf 마다 두 배. JAX 의 가장 자주 쓰는 utility.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.