C.W.K.
Stream
Lesson 03 of 05 · published

성능 Profiling 과 디버깅

~8 min · ecosystem, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

학습이 느릴 때 — 무엇이 병목인가? JAX 가 강력한 profiling 도구 제공.

jax.profiler — TensorBoard trace

import jax
import jax.numpy as jnp

# trace 시작
jax.profiler.start_trace("/tmp/profile-data")

# 측정할 코드
for _ in range(100):
    state, loss = train_step(state, batch)
    state.params["w"].block_until_ready()

# trace 종료
jax.profiler.stop_trace()

그 후 — TensorBoard 로 열기:

pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir /tmp/profile-data
# Profile 탭 → trace_viewer 또는 op_profile

각 op 의 시간, device 활용도, host ↔ device 통신 등 가시화. 학습 step 의 병목을 찾는 데 결정적.

trace_annotation — 코드의 부분에 이름 부여

@jax.profiler.annotate_function
def model_forward(params, x):
    return ...

# 또는 context manager
with jax.profiler.TraceAnnotation("data_loading"):
    batch = next(loader)

with jax.profiler.TraceAnnotation("train_step"):
    state = train_step(state, batch)

profile 결과에서 — 명명된 부분이 별도 표시. "어디가 느린가" 가 시각적.

perfetto / Chrome trace

TensorBoard 가 무거우면 — 더 가벼운 perfetto:

jax.profiler.start_trace("/tmp/perfetto-trace")
# code
jax.profiler.stop_trace()

https://ui.perfetto.dev/ 에서 직접 trace 파일 열기 가능.

compile 시간 분석

import time

@jax.jit
def step(x):
    # 복잡한 계산
    return ...

# 첫 호출 — compile
t = time.time()
y = step(x).block_until_ready()
print(f"compile + first run: {time.time()-t:.2f}s")

# 두 번째 — pure run
t = time.time()
y = step(x).block_until_ready()
print(f"warm run: {time.time()-t:.4f}s")

compile 이 너무 길면 — function 을 더 작게 쪼개거나, Python loop unrolling 을 scan 으로 대체.

memory 사용량

# peak memory 측정
peak = jax.live_arrays()
total_bytes = sum(arr.nbytes for arr in peak)
print(f"live array memory: {total_bytes / 1e9:.2f} GB")

OOM 발생 시 — 어떤 array 가 살아있는지 확인. 보통 — checkpoint 안 한 activation, 큰 dataset 이 device 에 머물러 있는 경우.

backend 정보

print(jax.devices())
print(jax.default_backend())
print(jax.local_device_count())

# XLA HLO 보기 — compile 결과의 IR
print(jax.xla_computation(step)(x).as_hlo_text())

HLO IR — 어떤 op 으로 compile 되었는지 정확히. fusion 이 의도대로 되었는지 확인 가능.

📊 profiling 우선순위

(1) 학습 step 시간 측정 (block_until_ready) — 1차 신호. (2) tensorboard trace — host vs device 시간. (3) 만약 host 가 병목 — data loader, host ↔ device transfer. (4) device 가 병목 — op profile, fusion 검증. (5) compile 이 길면 — function 분할, scan 사용. 근거 없는 micro-optimization 보다 — profile 결과에 따라.

경험적 — 학습 코드의 첫 작성 후 — 80% 의 case 는 profile 통해 1.5-3x 속도 향상 가능. 그 이후는 — model 자체 구조 / hardware 한계.

Code

import jax
import jax.numpy as jnp

# Check what's being compiled
@jax.jit
def my_function(x):
    return jnp.sin(x) + jnp.cos(x)

# Inspect the jaxpr (JAX's intermediate representation)
jaxpr = jax.make_jaxpr(my_function)(jnp.ones(3))
print(jaxpr)
# Shows the operations JAX will compile — useful for checking
# that your function isn't doing unexpected work

# Inspect compiled HLO
compiled = jax.jit(my_function).lower(jnp.ones(3)).compile()
print(compiled.cost_analysis())
# Shows estimated FLOPs, memory usage, etc.
import jax

# Profile with TensorBoard integration
jax.profiler.start_trace('/tmp/jax_profile')

# Run your training code
for step in range(100):
    params, loss = train_step(params, batch)

jax.profiler.stop_trace()
# Then: tensorboard --logdir=/tmp/jax_profile

# Or use the context manager
with jax.profiler.trace('/tmp/jax_profile'):
    for step in range(100):
        params, loss = train_step(params, batch)
# Set this to log when JIT recompiles
jax.config.update("jax_log_compiles", True)

# Now you'll see messages like:
# "Compiling my_function (..." every time JIT compiles a new variant

# Check XLA compilation logs for detailed info
# JAX_LOG_COMPILES=1 python my_script.py
# Timing best practices
import time

@jax.jit
def fast_fn(x):
    return jnp.linalg.svd(x, full_matrices=False)

x = jax.random.normal(jax.random.key(0), (1000, 500))

# First call includes compilation time
start = time.time()
result = fast_fn(x)
jax.block_until_ready(result)
print(f"First call (includes compile): {time.time() - start:.4f}s")

# Second call is the real runtime
start = time.time()
result = fast_fn(x)
jax.block_until_ready(result)
print(f"Second call (actual runtime): {time.time() - start:.4f}s")

External links

Exercise

50-step 학습 run 을 jax.profiler.start_trace 로 profile. trace 를 TensorBoard 또는 perfetto 에서 열기. 가장 느린 op 1 개 찾기. 한 가지만 최적화 — matmul 재배치 정도라도. profiling 먼저, 최적화 다음.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.