dict / list / tuple 외에 — 직접 만든 클래스도 pytree 로 등록 가능. @dataclass 가 자주 쓰임.
방법 1: @register_pytree_node_class
import jax
from dataclasses import dataclass
@jax.tree_util.register_pytree_node_class
@dataclass
class TrainState:
params: dict
opt_state: dict
step: int
def tree_flatten(self):
children = (self.params, self.opt_state, self.step)
aux_data = None # 정적 메타데이터 (없음)
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
# 사용
state = TrainState(
params={"w": jnp.zeros(10)},
opt_state={"momentum": jnp.zeros(10)},
step=0,
)
# pytree op 동작
zeroed = jax.tree.map(jnp.zeros_like, state) # ← 이제 정상 동작
flat, treedef = jax.tree.flatten(state)
방법 2: register_pytree_node (함수형)
class MyState:
def __init__(self, x, y):
self.x = x
self.y = y
def state_flatten(s):
return (s.x, s.y), None
def state_unflatten(aux, children):
return MyState(*children)
jax.tree_util.register_pytree_node(MyState, state_flatten, state_unflatten)
방법 3: Flax / Equinox 식 dataclass
가장 흔한 pattern — chex.dataclass 또는 flax.struct.dataclass:
from flax import struct
@struct.dataclass
class TrainState:
params: dict
opt_state: dict
step: int
# 자동으로 pytree 등록됨
state = TrainState(params={"w": jnp.zeros(10)}, opt_state={}, step=0)
zeroed = jax.tree.map(jnp.zeros_like, state) # OK
또는 equinox.Module:
import equinox as eqx
class MLP(eqx.Module):
layer1: eqx.nn.Linear
layer2: eqx.nn.Linear
def __init__(self, key):
k1, k2 = jax.random.split(key)
self.layer1 = eqx.nn.Linear(10, 20, key=k1)
self.layer2 = eqx.nn.Linear(20, 5, key=k2)
def __call__(self, x):
return self.layer2(jax.nn.relu(self.layer1(x)))
model = MLP(jax.random.PRNGKey(0))
# model 자체가 pytree
flat, _ = jax.tree.flatten(model)
aux_data 의 역할
tree_flatten 의 두 번째 반환값 — 정적 메타데이터. trace 시점에 hash 됨 (cache key 의 일부):
@jax.tree_util.register_pytree_node_class
@dataclass
class FixedShape:
data: jnp.ndarray
shape_info: tuple # 정적 (hash 가능)
def tree_flatten(self):
return (self.data,), self.shape_info # data 는 dynamic, shape 은 static
@classmethod
def tree_unflatten(cls, aux, children):
return cls(children[0], aux)
aux_data 가 바뀌면 — jit 이 새 trace. 그래서 — 자주 안 바뀌는 메타데이터만 aux_data 로.
💡 언제 custom pytree?
(1) train state — flax struct 또는 chex dataclass. (2) NN model class — equinox 가 자동. (3) RL state — 직접 register. 그러나 — 가급적 dict 로 시작해. dict 가 가장 단순. dict 가 답답해질 때 dataclass 로 옮겨. 처음부터 custom 만들 필요 없음.
중요한 한 가지 — pytree 등록 후엔 — 모든 JAX 변환 (jit, grad, vmap, scan 등) 이 자동으로 walk. PyTorch 의 nn.Module 같은 별도 추상화 없이 — pytree + 함수 만으로 충분.