C.W.K.
Stream
Lesson 04 of 05 · published

Custom Pytree Nodes

~11 min · pytrees, jax, tutorial


Beyond dict / list / tuple, you can also register your own classes as pytrees. @dataclass is a common choice for these classes.

Method 1: @register_pytree_node_class

import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class TrainState:
    params: dict
    opt_state: dict
    step: int

    def tree_flatten(self):
        children = (self.params, self.opt_state, self.step)
        aux_data = None   # static metadata (none here)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Usage
state = TrainState(
    params={"w": jnp.zeros(10)},
    opt_state={"momentum": jnp.zeros(10)},
    step=0,
)

# pytree ops now work
zeroed = jax.tree.map(jnp.zeros_like, state)   # ← now works
flat, treedef = jax.tree.flatten(state)
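A minimal round-trip check, assuming the TrainState class above (redefined here so the snippet runs on its own): flatten and unflatten are inverses, and jit accepts the registered class directly. The helper bump is hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class TrainState:
    params: dict
    opt_state: dict
    step: int

    def tree_flatten(self):
        # all three fields are children; no static metadata
        return (self.params, self.opt_state, self.step), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

state = TrainState(params={"w": jnp.zeros(10)},
                   opt_state={"momentum": jnp.zeros(10)}, step=0)

flat, treedef = jax.tree.flatten(state)
print(len(flat))                              # 3: w, momentum, step
rebuilt = jax.tree.unflatten(treedef, flat)   # back to a TrainState
print(type(rebuilt).__name__, rebuilt.step)   # TrainState 0

# hypothetical helper: jit takes the registered class as an argument
@jax.jit
def bump(s):
    return TrainState(s.params, s.opt_state, s.step + 1)

print(bump(state).step)   # 1 (returned as a JAX array)
```

Because step is a child rather than aux_data, it is traced like any other array, so changing its value does not by itself trigger a retrace.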

Method 2: register_pytree_node (functional style)

import jax

class MyState:
    def __init__(self, x, y):
        self.x = x
        self.y = y

def state_flatten(s):
    return (s.x, s.y), None

def state_unflatten(aux, children):
    return MyState(*children)

jax.tree_util.register_pytree_node(MyState, state_flatten, state_unflatten)
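A short usage sketch for the functional form, assuming the MyState registration above (repeated so the snippet is self-contained). This form is convenient when you can't add methods to a class, for example one defined in another library. The input values are made up for illustration.

```python
import jax
import jax.numpy as jnp

class MyState:
    def __init__(self, x, y):
        self.x = x
        self.y = y

def state_flatten(s):
    return (s.x, s.y), None

def state_unflatten(aux, children):
    return MyState(*children)

jax.tree_util.register_pytree_node(MyState, state_flatten, state_unflatten)

s = MyState(jnp.array(1.0), jnp.array(2.0))

# tree.map rebuilds a MyState with every leaf transformed
shifted = jax.tree.map(lambda v: v + 1, s)
print(shifted.x, shifted.y)   # 2.0 3.0

# grad with respect to a MyState returns a MyState of gradients
g = jax.grad(lambda st: st.x * st.y)(s)
print(g.x, g.y)               # 2.0 1.0  (d/dx = y, d/dy = x)
```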

Method 3: Flax / Equinox-style dataclasses

The most common pattern is chex.dataclass or flax.struct.dataclass:

import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class TrainState:
    params: dict
    opt_state: dict
    step: int

# registered as a pytree automatically
state = TrainState(params={"w": jnp.zeros(10)}, opt_state={}, step=0)
zeroed = jax.tree.map(jnp.zeros_like, state)   # OK
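To see roughly what a decorator like flax.struct.dataclass automates, here is a simplified, hypothetical stand-in written in plain JAX. simple_struct_dataclass is an invented name, not a Flax API. It treats every field as a pytree child; Flax's real decorator also supports static fields via struct.field(pytree_node=False) and differs in its details.

```python
import dataclasses
import jax
import jax.numpy as jnp

def simple_struct_dataclass(cls):
    """Hypothetical, simplified stand-in for flax.struct.dataclass.

    Turns cls into a frozen dataclass and registers every field as a
    pytree child. The field names themselves are the static aux_data.
    """
    cls = dataclasses.dataclass(frozen=True)(cls)
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten(obj):
        return tuple(getattr(obj, n) for n in names), names

    def unflatten(aux, children):
        return cls(**dict(zip(aux, children)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    # immutable update helper, similar in spirit to Flax's .replace()
    cls.replace = lambda self, **kw: dataclasses.replace(self, **kw)
    return cls

@simple_struct_dataclass
class TrainState:
    params: dict
    opt_state: dict
    step: int

state = TrainState(params={"w": jnp.zeros(10)}, opt_state={}, step=0)
zeroed = jax.tree.map(jnp.zeros_like, state)   # works: registered by the decorator
next_state = state.replace(step=state.step + 1)  # new object; state is unchanged
```

The frozen dataclass means updates go through replace and return a new object instead of mutating in place, which fits JAX's functional style.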

Or equinox.Module:

import jax
import equinox as eqx

class MLP(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.layer1 = eqx.nn.Linear(10, 20, key=k1)
        self.layer2 = eqx.nn.Linear(20, 5, key=k2)

    def __call__(self, x):
        return self.layer2(jax.nn.relu(self.layer1(x)))

model = MLP(jax.random.PRNGKey(0))
# the model itself is a pytree
flat, _ = jax.tree.flatten(model)
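If Equinox isn't available, the "model is a pytree" idea can be sketched in plain JAX. PlainMLP below is a hypothetical hand-rolled class, not part of any library. The point is that jax.grad returns gradients with the same structure as the model, so an SGD update is a single jax.tree.map (the 0.1 learning rate is an arbitrary choice).

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class PlainMLP:
    """Hypothetical plain-JAX stand-in for the eqx.Module MLP above."""

    def __init__(self, w1, b1, w2, b2):
        self.w1, self.b1, self.w2, self.b2 = w1, b1, w2, b2

    @classmethod
    def init(cls, key):
        k1, k2 = jax.random.split(key)
        return cls(jax.random.normal(k1, (10, 20)) * 0.1, jnp.zeros(20),
                   jax.random.normal(k2, (20, 5)) * 0.1, jnp.zeros(5))

    def __call__(self, x):
        return jax.nn.relu(x @ self.w1 + self.b1) @ self.w2 + self.b2

    def tree_flatten(self):
        # the four arrays are the leaves; no static config
        return (self.w1, self.b1, self.w2, self.b2), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

model = PlainMLP.init(jax.random.PRNGKey(0))
x = jnp.ones(10)

def loss(m):
    return jnp.sum(m(x) ** 2)

grads = jax.grad(loss)(model)   # a PlainMLP holding gradients
new_model = jax.tree.map(lambda p, g: p - 0.1 * g, model, grads)  # one SGD step
```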

The role of aux_data

aux_data is the second value returned by tree_flatten: static metadata. It is stored in the treedef and hashed at trace time, as part of jit's cache key:

import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class FixedShape:
    data: jnp.ndarray
    shape_info: tuple   # static (must be hashable)

    def tree_flatten(self):
        return (self.data,), self.shape_info   # data is dynamic, shape_info is static

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)
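A small check of the dynamic/static split, assuming the FixedShape class above (redefined so the snippet is self-contained; as_matrix is a hypothetical helper). Only data becomes a leaf, tree.map leaves shape_info alone, and inside jit shape_info arrives as a concrete Python tuple, which is exactly what reshape needs.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class FixedShape:
    data: jnp.ndarray
    shape_info: tuple   # static, stored as aux_data

    def tree_flatten(self):
        return (self.data,), self.shape_info

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)

fs = FixedShape(jnp.arange(6.0), (2, 3))

leaves, treedef = jax.tree.flatten(fs)
print(len(leaves))          # 1: only data is a leaf

scaled = jax.tree.map(lambda x: x * 10, fs)
print(scaled.shape_info)    # (2, 3): aux_data passes through untouched

@jax.jit
def as_matrix(f):
    # f.data is a tracer here, but f.shape_info is a plain tuple,
    # so it can be used where JAX requires a static shape
    return f.data.reshape(f.shape_info)

print(as_matrix(fs).shape)  # (2, 3)
```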

When aux_data changes, jit traces (and compiles) the function again. So only put metadata that rarely changes into aux_data.
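One way to observe the retrace, as a sketch: a Python side effect inside a jitted function runs only while tracing, so a counter (trace_count, a hypothetical name) counts traces. The first call traces, the second reuses the cache (same aux_data, same array shape and dtype), and the third, with different aux_data, traces again. Use this trick for debugging only.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class FixedShape:
    data: jnp.ndarray
    shape_info: tuple   # static, stored as aux_data

    def tree_flatten(self):
        return (self.data,), self.shape_info

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)

trace_count = 0

@jax.jit
def total(fs):
    global trace_count
    trace_count += 1   # Python side effect: runs only during tracing
    return jnp.sum(fs.data)

total(FixedShape(jnp.ones(3), (3,)))    # first call: trace
total(FixedShape(jnp.zeros(3), (3,)))   # same treedef and shapes: cache hit
total(FixedShape(jnp.ones(3), (1, 3)))  # different aux_data: retrace
print(trace_count)                      # 2
```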

💡 When should you use a custom pytree?

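(1) Train state: flax.struct or chex dataclass. (2) NN model classes: Equinox handles this automatically. (3) RL state: register it yourself. That said, start with a dict whenever you can; a dict is the simplest option. Move to a dataclass once a dict starts to feel cramped. There is no need to build a custom pytree from day one.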
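A minimal sketch of the "start with a dict" advice: nested dicts are already pytrees, so jit and jax.tree.map work on them with no registration at all. The state layout, sgd_step, and the 0.1 learning rate are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

# nested dicts are built-in pytrees: no registration needed
state = {"params": {"w": jnp.ones(3)}, "step": 0}

@jax.jit
def sgd_step(state, grads):
    # grads mirrors the structure of state["params"]
    new_params = jax.tree.map(lambda p, g: p - 0.1 * g, state["params"], grads)
    return {"params": new_params, "step": state["step"] + 1}

grads = {"w": jnp.full(3, 2.0)}
new_state = sgd_step(state, grads)   # w becomes 1 - 0.1 * 2 = 0.8
```

If the dict grows keys you keep misspelling, or you want methods on the state, that is a sign to move to one of the dataclass options above.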

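One important point: once a class is registered as a pytree, every JAX transformation (jit, grad, vmap, scan, and so on) walks into it automatically. You don't need a separate abstraction like PyTorch's nn.Module; pytrees plus functions are enough.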
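As a sketch of that claim, here is one registered class passing through vmap, scan, and jit together. Particle, simulate, and the Euler step with dt = 0.1 are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Particle:
    """Hypothetical example class: a position and a velocity."""

    def __init__(self, pos, vel):
        self.pos, self.vel = pos, vel

    def tree_flatten(self):
        return (self.pos, self.vel), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

def step(p, _):
    # one Euler step with dt = 0.1; the scan carry is a Particle
    new_p = Particle(p.pos + 0.1 * p.vel, p.vel)
    return new_p, new_p.pos   # (next carry, value recorded per step)

def simulate(p):
    return jax.lax.scan(step, p, None, length=10)

# a batch of 4 particles: every leaf has a leading axis of size 4
batch = Particle(pos=jnp.arange(4.0), vel=jnp.ones(4))
final, traj = jax.jit(jax.vmap(simulate))(batch)
# final.pos is about [1, 2, 3, 4]; traj has shape (4, 10)
```

vmap maps over the leading axis of every leaf, so simulate sees a Particle of scalars, and scan carries that Particle from one step to the next.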

Code

import jax
import jax.numpy as jnp

class LinearParams:
    """A simple container for linear layer parameters."""
    def __init__(self, weights, bias):
        self.weights = weights
        self.bias = bias

    def __repr__(self):
        return f"LinearParams(w={self.weights.shape}, b={self.bias.shape})"

# Without registration, JAX sees this as an opaque leaf
layer = LinearParams(jnp.ones((3, 4)), jnp.zeros(4))
print(jax.tree.leaves(layer))  # [LinearParams(w=(3, 4), b=(4,))]: the object itself is one leaf

def linear_flatten(obj):
    """Returns (children, aux_data)."""
    children = (obj.weights, obj.bias)  # the arrays JAX should trace
    aux_data = None                      # any static metadata (none here)
    return children, aux_data

def linear_unflatten(aux_data, children):
    """Reconstructs the object from (aux_data, children)."""
    weights, bias = children
    return LinearParams(weights, bias)

# Register with JAX
jax.tree_util.register_pytree_node(
    LinearParams,
    linear_flatten,
    linear_unflatten,
)

# Now JAX sees inside it
layer = LinearParams(jnp.ones((3, 4)), jnp.zeros(4))
print(jax.tree.leaves(layer))
# [Array([[1., ...]], dtype=float32), Array([0., ...], dtype=float32)]

# tree.map works too
doubled = jax.tree.map(lambda x: x * 2, layer)
print(doubled)  # LinearParams(w=(3, 4), b=(4,)), with doubled values

# Example with aux_data: a layer that stores its activation name
class DenseParams:
    def __init__(self, weights, bias, activation='relu'):
        self.weights = weights
        self.bias = bias
        self.activation = activation  # static config

def dense_flatten(obj):
    children = (obj.weights, obj.bias)
    aux_data = obj.activation  # stored as static metadata
    return children, aux_data

def dense_unflatten(aux_data, children):
    weights, bias = children
    return DenseParams(weights, bias, activation=aux_data)

jax.tree_util.register_pytree_node(DenseParams, dense_flatten, dense_unflatten)

# Now grad works through DenseParams
layer = DenseParams(jnp.ones((3, 4)), jnp.zeros(4), activation='relu')
grads = jax.grad(lambda l: jnp.sum(l.weights))(layer)
print(grads.weights)  # all ones
print(grads.activation)  # 'relu' — static data passed through unchanged


Exercise

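Write a tiny model state as a @dataclass and register it as a custom pytree with register_pytree_node_class. Pass it through jax.tree.map, then jit a function that takes it as an argument. Before this, pytrees probably seemed limited to dicts; now they're universal.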
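One possible solution sketch. The class name TinyState, the fields w and b, and the predict function are hypothetical; many other layouts work.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class TinyState:
    w: jnp.ndarray
    b: jnp.ndarray

    def tree_flatten(self):
        return (self.w, self.b), None   # both fields are children

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

state = TinyState(w=jnp.ones((2, 3)), b=jnp.zeros(3))

# 1) jax.tree.map passes through the custom node and returns a TinyState
halved = jax.tree.map(lambda x: x / 2, state)

# 2) jit a function that takes the state as an argument
@jax.jit
def predict(s, x):
    return x @ s.w + s.b

y = predict(halved, jnp.ones(2))   # [1., 1., 1.]
```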
