jit 이 들어가면 — 평범한 print, breakpoint 가 안 동작. JAX 가 그걸 위해 별도 도구 제공.
jax.debug.print
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("x has shape {s}, mean {m}", s=x.shape, m=jnp.mean(x))
y = x * 2
jax.debug.print("y[0] = {y0}", y0=y[0])
return y
f(jnp.array([1., 2., 3.]))
# x has shape (3,), mean 2.0
# y[0] = 2.0
중요한 차이:
- 일반
print— trace 시점 한 번만 도는, 출력은 Tracer object. jax.debug.print— 실제 호출마다 작동, 실제 값 출력. JIT compile 안에 들어감.
jax.debug.breakpoint
@jax.jit
def f(x):
y = x ** 2
jax.debug.breakpoint() # 여기서 멈춤, 변수 inspect 가능
z = jnp.sin(y)
return z
f(jnp.array([1., 2., 3.]))
# (Pdb) p y
# array([1., 4., 9.])
# (Pdb) p z
# *** NameError: ... (아직 안 만들어짐)
jit 안에서도 pdb 비슷한 interactive 디버그.
jax.debug.callback
임의 Python 함수 호출 (예: numpy 작업, plotting):
def log_to_disk(arr):
np.save(f"step_{arr.shape}.npy", arr)
@jax.jit
def f(x):
jax.debug.callback(log_to_disk, x) # x 의 실제 값 받아 외부 함수
return x ** 2
주의 — callback 은 — jit 의 일부가 아니라 host 콜. 매 호출마다 host ↔ device 동기화. 학습 loop 에 무자비하게 넣으면 성능 폭락. logging 이나 checkpoint 정도가 적당.
shape / dtype 추적 — eval_shape
def f(x, y):
return jnp.einsum("ij,jk->ik", x, y)
# 함수 안 돌리고 shape 만
output_shape = jax.eval_shape(f, jnp.zeros((3, 4)), jnp.zeros((4, 5)))
print(output_shape)
# ShapeDtypeStruct(shape=(3, 5), dtype=float32)
큰 모델의 shape 검증 — 메모리 안 잡고 가능.
NaN 추적 — debug_nans
import jax
jax.config.update("jax_debug_nans", True)
# 이제 — 어디서든 NaN 나면 즉시 에러. 학습 발산 발견에 결정적.
또는 환경변수 JAX_DEBUG_NANS=1.
sharding 디버깅 — visualize_array_sharding
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))
x = jax.device_put(jnp.zeros((1024, 512)), NamedSharding(mesh, P("data", "model")))
jax.debug.visualize_array_sharding(x)
# +---+---+
# |d0 |d1 | ← row 0 (data axis 0)
# +---+---+
# |d2 |d3 |
# ...
큰 array 가 어떻게 분산되어 있는지 시각화. 분산 학습 디버깅에 필수.
💡 점진적 디버깅
(1) jit 빼고 eager 로 — pure Python 처럼. (2) jax.debug.print 추가 — 흐름 확인. (3) jax.debug.breakpoint — 인터랙티브. (4) jax.config.update("jax_debug_nans", True) — 발산 위치. (5) eval_shape — shape 만 빠르게. (6) visualize_array_sharding — 분산 구조. 상황에 따라 도구 선택.
JAX 디버깅의 mental model — "trace 시점 vs run 시점 분리". 이 둘이 헷갈리지 않으면 — 거의 모든 JIT bug 가 풀려.