C.W.K.
Stream
Lesson 02 of 05 · published

jax.tree.map: pytree 변환

~8 min · pytrees, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.tree.map 은 — pytree 작업의 핵심 도구. 모든 leaf 에 함수 적용, 트리 구조는 보존.

import jax
import jax.numpy as jnp

# 단일 pytree
params = {
    "W": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "b": jnp.array([0.5, -0.5]),
}

# 모든 leaf 를 2 배
doubled = jax.tree.map(lambda x: x * 2, params)
# {"W": [[2,4],[6,8]], "b": [1,-1]}

# 모든 leaf 의 norm
norms = jax.tree.map(jnp.linalg.norm, params)
# {"W": ..., "b": ...} — 각 leaf 의 norm

여러 pytree

params = {"W": jnp.zeros((3,4)), "b": jnp.zeros(4)}
grads  = {"W": jnp.ones((3,4)),  "b": jnp.ones(4)}

# 같은 구조의 두 pytree 를 zip-map
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)

# 세 개도 OK
old_params = ...
update = jax.tree.map(
    lambda p_old, p_new, g: 0.9 * p_old + 0.1 * (p_new - g),
    old_params, params, grads,
)

구조가 다르면 에러

params = {"W": ..., "b": ...}
grads  = {"W": ..., "bias": ...}   # key 다름!

jax.tree.map(lambda p, g: p - g, params, grads)
# ValueError: Tree structures do not match

실전 패턴

# 1. zero-out
zeroed = jax.tree.map(jnp.zeros_like, params)

# 2. clip gradients
def clip(g, max_norm=1.0):
    norm = jnp.linalg.norm(g)
    return jnp.where(norm > max_norm, g * max_norm / norm, g)

clipped = jax.tree.map(clip, grads)

# 3. add noise
def add_noise(p, key, scale=0.01):
    return p + scale * jax.random.normal(key, p.shape)

# 각 leaf 에 다른 key 필요 — 좀 까다로워
keys = jax.random.split(key, len(jax.tree.leaves(params)))
key_tree = jax.tree.unflatten(jax.tree.structure(params), list(keys))
noisy = jax.tree.map(add_noise, params, key_tree)

# 4. parameter 통계
total = sum(x.size for x in jax.tree.leaves(params))
total_norm = jnp.sqrt(sum(jnp.sum(x**2) for x in jax.tree.leaves(params)))

tree.reduce

# 모든 leaf 의 sum-of-squares
sq_sum = jax.tree.reduce(
    lambda acc, x: acc + jnp.sum(x**2),
    params,
    initializer=0.0,
)

# 또는 leaves 로 list 만든 후 reduce
sq_sum = sum(jnp.sum(x**2) for x in jax.tree.leaves(params))

tree.map_with_path

# leaf 의 위치 (path) 도 함께
def f(path, value):
    print(f"path={path}, value.shape={value.shape}")
    return value

jax.tree.map_with_path(f, params)
# path=('W',), value.shape=(3, 4)
# path=('b',), value.shape=(4,)

특정 layer 만 별도 처리할 때 유용 — "이름이 'bias' 인 leaf 는 weight decay 안 함" 같은 패턴.

💡 tree.map 은 jit 안에서 자유

jit 안에서 tree.map 호출 OK. trace 시점에 tree 가 풀려서, 같은 op 을 모든 leaf 에 적용한 IR 이 생성됨. params 가 100 개 leaf 면 IR 도 그만큼 커짐 — compile 시간이 늘 수 있음. 매우 큰 model 에서 신경 써야 할 부분.

Code

import jax
import jax.numpy as jnp

params = {
    'layer1': {'w': jnp.ones((3, 4)), 'b': jnp.zeros(4)},
    'layer2': {'w': jnp.ones((4, 2)), 'b': jnp.zeros(2)},
}

# Apply a function to every leaf
scaled = jax.tree.map(lambda x: x * 2.0, params)
print(scaled['layer1']['w'][0, 0])  # 2.0

# Count parameters
total_params = sum(jax.tree.map(lambda x: x.size, params).values()
                   for layer in jax.tree.leaves(
                       jax.tree.map(lambda x: x.size, params)))
# Better way:
total = sum(x.size for x in jax.tree.leaves(params))
print(f"Total parameters: {total}")  # 3*4 + 4 + 4*2 + 2 = 22
# SGD update: params = params - lr * grads
def sgd_update(params, grads, lr=0.01):
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

# Compute gradients (returns a pytree matching params structure)
def loss_fn(params, x, y):
    h = jax.nn.relu(x @ params['layer1']['w'] + params['layer1']['b'])
    pred = h @ params['layer2']['w'] + params['layer2']['b']
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((5, 3))
y = jnp.ones((5, 2))
grads = jax.grad(loss_fn)(params, x, y)

# grads has SAME structure as params
new_params = sgd_update(params, grads)
# new_params also has the same structure!
# Modern API (recommended)
result = jax.tree.map(fn, tree)

# Legacy API (still works)
result = jax.tree_util.tree_map(fn, tree)

# Other useful functions in the modern API
leaves = jax.tree.leaves(tree)            # flat list of leaves
structure = jax.tree.structure(tree)      # just the tree structure
flat, treedef = jax.tree.flatten(tree)    # leaves + structure

External links

Exercise

model params 의 pytree (dict of jnp arrays). jax.tree.map 으로 (1) zero, (2) Gaussian noise 추가, (3) leaf 마다 norm clip. jit 합성. 같은 primitive 가 Flax/Optax 내부에서 사용.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.