jax.tree.map 은 — pytree 작업의 핵심 도구. 모든 leaf 에 함수 적용, 트리 구조는 보존.
import jax
import jax.numpy as jnp
# 단일 pytree
params = {
"W": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
"b": jnp.array([0.5, -0.5]),
}
# 모든 leaf 를 2 배
doubled = jax.tree.map(lambda x: x * 2, params)
# {"W": [[2,4],[6,8]], "b": [1,-1]}
# 모든 leaf 의 norm
norms = jax.tree.map(jnp.linalg.norm, params)
# {"W": ..., "b": ...} — 각 leaf 의 norm
여러 pytree
params = {"W": jnp.zeros((3,4)), "b": jnp.zeros(4)}
grads = {"W": jnp.ones((3,4)), "b": jnp.ones(4)}
# 같은 구조의 두 pytree 를 zip-map
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
# 세 개도 OK
old_params = ...
update = jax.tree.map(
lambda p_old, p_new, g: 0.9 * p_old + 0.1 * (p_new - g),
old_params, params, grads,
)
구조가 다르면 에러
params = {"W": ..., "b": ...}
grads = {"W": ..., "bias": ...} # key 다름!
jax.tree.map(lambda p, g: p - g, params, grads)
# ValueError: Tree structures do not match
실전 패턴
# 1. zero-out
zeroed = jax.tree.map(jnp.zeros_like, params)
# 2. clip gradients
def clip(g, max_norm=1.0):
norm = jnp.linalg.norm(g)
return jnp.where(norm > max_norm, g * max_norm / norm, g)
clipped = jax.tree.map(clip, grads)
# 3. add noise
def add_noise(p, key, scale=0.01):
return p + scale * jax.random.normal(key, p.shape)
# 각 leaf 에 다른 key 필요 — 좀 까다로워
keys = jax.random.split(key, len(jax.tree.leaves(params)))
key_tree = jax.tree.unflatten(jax.tree.structure(params), list(keys))
noisy = jax.tree.map(add_noise, params, key_tree)
# 4. parameter 통계
total = sum(x.size for x in jax.tree.leaves(params))
total_norm = jnp.sqrt(sum(jnp.sum(x**2) for x in jax.tree.leaves(params)))
tree.reduce
# 모든 leaf 의 sum-of-squares
sq_sum = jax.tree.reduce(
lambda acc, x: acc + jnp.sum(x**2),
params,
initializer=0.0,
)
# 또는 leaves 로 list 만든 후 reduce
sq_sum = sum(jnp.sum(x**2) for x in jax.tree.leaves(params))
tree.map_with_path
# leaf 의 위치 (path) 도 함께
def f(path, value):
print(f"path={path}, value.shape={value.shape}")
return value
jax.tree.map_with_path(f, params)
# path=('W',), value.shape=(3, 4)
# path=('b',), value.shape=(4,)
특정 layer 만 별도 처리할 때 유용 — "이름이 'bias' 인 leaf 는 weight decay 안 함" 같은 패턴.
💡 tree.map 은 jit 안에서 자유
jit 안에서 tree.map 호출 OK. trace 시점에 tree 가 풀려서, 같은 op 을 모든 leaf 에 적용한 IR 이 생성됨. params 가 100 개 leaf 면 IR 도 그만큼 커짐 — compile 시간이 늘 수 있음. 매우 큰 model 에서 신경 써야 할 부분.