학습이 느릴 때 — 무엇이 병목인가? JAX 가 강력한 profiling 도구 제공.
jax.profiler — TensorBoard trace
import jax
import jax.numpy as jnp
# trace 시작
jax.profiler.start_trace("/tmp/profile-data")
# 측정할 코드
for _ in range(100):
state, loss = train_step(state, batch)
state.params["w"].block_until_ready()
# trace 종료
jax.profiler.stop_trace()
그 후 — TensorBoard 로 열기:
pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir /tmp/profile-data
# Profile 탭 → trace_viewer 또는 op_profile
각 op 의 시간, device 활용도, host ↔ device 통신 등 가시화. 학습 step 의 병목을 찾는 데 결정적.
trace_annotation — 코드의 부분에 이름 부여
@jax.profiler.annotate_function
def model_forward(params, x):
return ...
# 또는 context manager
with jax.profiler.TraceAnnotation("data_loading"):
batch = next(loader)
with jax.profiler.TraceAnnotation("train_step"):
state = train_step(state, batch)
profile 결과에서 — 명명된 부분이 별도 표시. "어디가 느린가" 가 시각적.
perfetto / Chrome trace
TensorBoard 가 무거우면 — 더 가벼운 perfetto:
jax.profiler.start_trace("/tmp/perfetto-trace")
# code
jax.profiler.stop_trace()
https://ui.perfetto.dev/ 에서 직접 trace 파일 열기 가능.
compile 시간 분석
import time
@jax.jit
def step(x):
# 복잡한 계산
return ...
# 첫 호출 — compile
t = time.time()
y = step(x).block_until_ready()
print(f"compile + first run: {time.time()-t:.2f}s")
# 두 번째 — pure run
t = time.time()
y = step(x).block_until_ready()
print(f"warm run: {time.time()-t:.4f}s")
compile 이 너무 길면 — function 을 더 작게 쪼개거나, Python loop unrolling 을 scan 으로 대체.
memory 사용량
# peak memory 측정
peak = jax.live_arrays()
total_bytes = sum(arr.nbytes for arr in peak)
print(f"live array memory: {total_bytes / 1e9:.2f} GB")
OOM 발생 시 — 어떤 array 가 살아있는지 확인. 보통 — checkpoint 안 한 activation, 큰 dataset 이 device 에 머물러 있는 경우.
backend 정보
print(jax.devices())
print(jax.default_backend())
print(jax.local_device_count())
# XLA HLO 보기 — compile 결과의 IR
print(jax.xla_computation(step)(x).as_hlo_text())
HLO IR — 어떤 op 으로 compile 되었는지 정확히. fusion 이 의도대로 되었는지 확인 가능.
📊 profiling 우선순위
(1) 학습 step 시간 측정 (block_until_ready) — 1차 신호. (2) tensorboard trace — host vs device 시간. (3) 만약 host 가 병목 — data loader, host ↔ device transfer. (4) device 가 병목 — op profile, fusion 검증. (5) compile 이 길면 — function 분할, scan 사용. 근거 없는 micro-optimization 보다 — profile 결과에 따라.
경험적 — 학습 코드의 첫 작성 후 — 80% 의 case 는 profile 통해 1.5-3x 속도 향상 가능. 그 이후는 — model 자체 구조 / hardware 한계.