C.W.K.
Stream
Lesson 02 of 05 · published

Tracing 깊이 보기: Abstract Values 와 ShapedArray

~8 min · jit, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

trace 가 어떻게 일어나는지 들여다보면 — JAX 의 모든 동작이 자연스러워져.

jit 함수를 처음 호출하면 — 인자가 실제 array 가 아니라 Tracer object 로 대체됨. Tracer 는 array 와 비슷하지만 — concrete 한 값 대신 abstract 한 metadata 만 가짐:

  • shape — 모양 (예: (32, 768))
  • dtype — type (예: float32)
  • 그게 다.

이 metadata 를 ShapedArray 라고 불러. 함수 안에서 일어나는 모든 연산은 ShapedArray 위에서 — 실제 계산이 아니라 — "어떤 op 이 어떤 shape/dtype 을 만들지" 추적.

import jax

@jax.jit
def f(x):
    print(type(x))   # 첫 호출에서 한 번만 출력. 무엇이 찍힐까?
    print(x)
    return jnp.sum(x ** 2)

f(jnp.arange(5.))
# <class 'jax.interpreters.partial_eval.JaxprTracer'>
# Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=1/0)>

Python 의 print 는 trace 시점에 한 번 도니까 — 여기서 보이는 게 Tracer object 의 정체. x 는 5 개짜리 float32 array 라는 건 알지만, 값은 모름.

이게 왜 중요하냐면 — Tracer 의 한계 때문에 Python 의 일부 동작이 안 됨:

@jax.jit
def f(x):
    if x > 0:        # ❌ Tracer 의 truth value 가 없음
        return x
    return -x

# 해결
@jax.jit
def f(x):
    return jnp.where(x > 0, x, -x)  # value 사용 안 함, abstract op

Tracer 가 가진 정보로 가능한 것:

  • x.shape — OK (static)
  • x.dtype — OK
  • x.ndim — OK
  • jnp.sum(x) — OK (op 추가, abstract)
  • x[0] — OK if index is static

안 되는 것:

  • x.item() — concrete 값 필요
  • int(x), float(x) — concrete 값
  • if x > 0: — concrete bool 필요
  • x.numpy() — concrete

대부분의 함수는 trace 가능 — abstract value 로도 op 흐름은 같으니까. 그래서 거의 모든 NumPy-style 코드가 그대로 jit 가능.

🔬 Jaxpr 보기

jax.make_jaxpr(f)(x) 호출하면 — JAX 가 trace 한 IR 을 인쇄. compile 안 하고 trace 만. 함수의 op 흐름이 정확히 보임. 디버깅 + 성능 분석에 유용. print(jax.make_jaxpr(f)(x)) — 한 번 해 봐.

import jax

def f(x):
    return jnp.sum(x ** 2 + jnp.sin(x))

print(jax.make_jaxpr(f)(jnp.arange(5.)))
# { lambda ; a:f32[5]. let
#     b:f32[5] = integer_pow[y=2] a
#     c:f32[5] = sin a
#     d:f32[5] = add b c
#     e:f32[] = reduce_sum[axes=(0,)] d
#   in (e,) }

읽는 법: let <var>:<type> = <op>[<params>] <args>. 5 개짜리 float32 가 들어와서 power, sin, add 가 차례로 적용되고 reduce_sum 으로 합쳐져 0-d float32 (scalar) 가 나옴. compile 단계에선 이 IR 이 XLA 에 넘어가.

이 mental model 한 번 잡으면 — JAX 의 거의 모든 에러 메시지가 해석 가능해.

Code

import jax
import jax.numpy as jnp

@jax.jit
def show_tracer(x):
    # During tracing, x is a Tracer wrapping a ShapedArray
    print(f"Type during trace: {type(x)}")
    # Shape is known:
    print(f"Shape: {x.shape}")
    # But values are NOT known:
    # if x[0] > 0: ...  # This would fail!
    return x + 1

show_tracer(jnp.array([1.0, 2.0, 3.0]))
# Prints something like:
# Type during trace: <class 'jax...JaxprTracer'>
# Shape: (3,)
import jax
import jax.numpy as jnp

def my_fn(x, y):
    z = jnp.sin(x) + jnp.cos(y)
    return jnp.sum(z)

# See the traced computation
jaxpr = jax.make_jaxpr(my_fn)(jnp.ones(3), jnp.ones(3))
print(jaxpr)
# { lambda ; a:f32[3] b:f32[3]. let
#     c:f32[3] = sin a
#     d:f32[3] = cos b
#     e:f32[3] = add c d
#     f:f32[] = reduce_sum[axes=(0,)] e
#   in (f,) }

External links

Exercise

jit-decorated 함수 안에 print(x) 추가. 실행. 출력 보고 Tracer object 식별. 그 다음 print 제거하고 jax.debug.print('{x}', x=x). 다시 실행 — 차이 확인. trace-time vs run-time 의 mental model 분리가 여기서 형성.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.