trace 가 어떻게 일어나는지 들여다보면 — JAX 의 모든 동작이 자연스러워져.
jit 함수를 처음 호출하면 — 인자가 실제 array 가 아니라 Tracer object 로 대체됨. Tracer 는 array 와 비슷하지만 — concrete 한 값 대신 abstract 한 metadata 만 가짐:
- shape — 모양 (예:
(32, 768)) - dtype — type (예:
float32) - 그게 다.
이 metadata 를 ShapedArray 라고 불러. 함수 안에서 일어나는 모든 연산은 ShapedArray 위에서 — 실제 계산이 아니라 — "어떤 op 이 어떤 shape/dtype 을 만들지" 추적.
import jax
@jax.jit
def f(x):
print(type(x)) # 첫 호출에서 한 번만 출력. 무엇이 찍힐까?
print(x)
return jnp.sum(x ** 2)
f(jnp.arange(5.))
# <class 'jax.interpreters.partial_eval.JaxprTracer'>
# Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=1/0)>
Python 의 print 는 trace 시점에 한 번 도니까 — 여기서 보이는 게 Tracer object 의 정체. x 는 5 개짜리 float32 array 라는 건 알지만, 값은 모름.
이게 왜 중요하냐면 — Tracer 의 한계 때문에 Python 의 일부 동작이 안 됨:
@jax.jit
def f(x):
if x > 0: # ❌ Tracer 의 truth value 가 없음
return x
return -x
# 해결
@jax.jit
def f(x):
return jnp.where(x > 0, x, -x) # value 사용 안 함, abstract op
Tracer 가 가진 정보로 가능한 것:
x.shape— OK (static)x.dtype— OKx.ndim— OKjnp.sum(x)— OK (op 추가, abstract)x[0]— OK if index is static
안 되는 것:
x.item()— concrete 값 필요int(x),float(x)— concrete 값if x > 0:— concrete bool 필요x.numpy()— concrete
대부분의 함수는 trace 가능 — abstract value 로도 op 흐름은 같으니까. 그래서 거의 모든 NumPy-style 코드가 그대로 jit 가능.
🔬 Jaxpr 보기
jax.make_jaxpr(f)(x) 호출하면 — JAX 가 trace 한 IR 을 인쇄. compile 안 하고 trace 만. 함수의 op 흐름이 정확히 보임. 디버깅 + 성능 분석에 유용. print(jax.make_jaxpr(f)(x)) — 한 번 해 봐.
import jax
def f(x):
return jnp.sum(x ** 2 + jnp.sin(x))
print(jax.make_jaxpr(f)(jnp.arange(5.)))
# { lambda ; a:f32[5]. let
# b:f32[5] = integer_pow[y=2] a
# c:f32[5] = sin a
# d:f32[5] = add b c
# e:f32[] = reduce_sum[axes=(0,)] d
# in (e,) }
읽는 법: let <var>:<type> = <op>[<params>] <args>. 5 개짜리 float32 가 들어와서 power, sin, add 가 차례로 적용되고 reduce_sum 으로 합쳐져 0-d float32 (scalar) 가 나옴. compile 단계에선 이 IR 이 XLA 에 넘어가.
이 mental model 한 번 잡으면 — JAX 의 거의 모든 에러 메시지가 해석 가능해.