C.W.K.
Stream
Lesson 04 of 05 · published

JIT + Sharding 아래의 디버깅

~11 min · advanced, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jit 이 들어가면 — 평범한 print, breakpoint 가 안 동작. JAX 가 그걸 위해 별도 도구 제공.

jax.debug.print

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("x has shape {s}, mean {m}", s=x.shape, m=jnp.mean(x))
    y = x * 2
    jax.debug.print("y[0] = {y0}", y0=y[0])
    return y

f(jnp.array([1., 2., 3.]))
# x has shape (3,), mean 2.0
# y[0] = 2.0

중요한 차이:

  • 일반 print — trace 시점 한 번만 도는, 출력은 Tracer object.
  • jax.debug.print — 실제 호출마다 작동, 실제 값 출력. JIT compile 안에 들어감.

jax.debug.breakpoint

@jax.jit
def f(x):
    y = x ** 2
    jax.debug.breakpoint()   # 여기서 멈춤, 변수 inspect 가능
    z = jnp.sin(y)
    return z

f(jnp.array([1., 2., 3.]))
# (Pdb)  p y
# array([1., 4., 9.])
# (Pdb)  p z
# *** NameError: ...  (아직 안 만들어짐)

jit 안에서도 pdb 비슷한 interactive 디버그.

jax.debug.callback

임의 Python 함수 호출 (예: numpy 작업, plotting):

def log_to_disk(arr):
    np.save(f"step_{arr.shape}.npy", arr)

@jax.jit
def f(x):
    jax.debug.callback(log_to_disk, x)   # x 의 실제 값 받아 외부 함수
    return x ** 2

주의 — callback 은 — jit 의 일부가 아니라 host 콜. 매 호출마다 host ↔ device 동기화. 학습 loop 에 무자비하게 넣으면 성능 폭락. logging 이나 checkpoint 정도가 적당.

shape / dtype 추적 — eval_shape

def f(x, y):
    return jnp.einsum("ij,jk->ik", x, y)

# 함수 안 돌리고 shape 만
output_shape = jax.eval_shape(f, jnp.zeros((3, 4)), jnp.zeros((4, 5)))
print(output_shape)
# ShapeDtypeStruct(shape=(3, 5), dtype=float32)

큰 모델의 shape 검증 — 메모리 안 잡고 가능.

NaN 추적 — debug_nans

import jax
jax.config.update("jax_debug_nans", True)

# 이제 — 어디서든 NaN 나면 즉시 에러. 학습 발산 발견에 결정적.

또는 환경변수 JAX_DEBUG_NANS=1.

sharding 디버깅 — visualize_array_sharding

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np

devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

x = jax.device_put(jnp.zeros((1024, 512)), NamedSharding(mesh, P("data", "model")))
jax.debug.visualize_array_sharding(x)
# +---+---+
# |d0 |d1 |   ← row 0 (data axis 0)
# +---+---+
# |d2 |d3 |
# ...

큰 array 가 어떻게 분산되어 있는지 시각화. 분산 학습 디버깅에 필수.

💡 점진적 디버깅

(1) jit 빼고 eager 로 — pure Python 처럼. (2) jax.debug.print 추가 — 흐름 확인. (3) jax.debug.breakpoint — 인터랙티브. (4) jax.config.update("jax_debug_nans", True) — 발산 위치. (5) eval_shape — shape 만 빠르게. (6) visualize_array_sharding — 분산 구조. 상황에 따라 도구 선택.

JAX 디버깅의 mental model — "trace 시점 vs run 시점 분리". 이 둘이 헷갈리지 않으면 — 거의 모든 JIT bug 가 풀려.

Code

import jax
import jax.numpy as jnp

@jax.jit
def buggy_function(x):
    # Regular print: only runs during tracing!
    print("This prints ONCE during tracing, not during execution")

    # jax.debug.print: runs during execution
    jax.debug.print("x = {x}", x=x)

    y = x ** 2
    jax.debug.print("y = {y}", y=y)

    # You can print inside conditions, scans, etc.
    z = jax.lax.cond(
        x.sum() > 0,
        lambda: (jax.debug.print("positive branch"), x * 2)[1],
        lambda: (jax.debug.print("negative branch"), x * -1)[1],
    )
    return z

result = buggy_function(jnp.array([1.0, -2.0, 3.0]))
# Prints:
# x = [1. -2. 3.]
# y = [1. 4. 9.]
# positive branch
@jax.jit
def function_with_breakpoint(x):
    y = jnp.sin(x)
    jax.debug.breakpoint()  # drops into pdb during execution
    z = jnp.cos(y)
    return z

# When this runs, you'll get an interactive pdb prompt
# where you can inspect y, x, etc.
# result = function_with_breakpoint(jnp.array(1.0))
import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
import jax.numpy as jnp

# Create a mesh of devices
devices = jax.devices()  # all available devices
mesh = Mesh(devices, axis_names=('data',))

# Or for multi-dimensional parallelism:
# devices = mesh_utils.create_device_mesh((2, 4))
# mesh = Mesh(devices, axis_names=('data', 'model'))

# Shard data across the 'data' axis
data_sharding = NamedSharding(mesh, P('data'))

# Replicate across all devices
replicated = NamedSharding(mesh, P())

# Note: jax.P is a convenient alias for PartitionSpec (since JAX 0.7.0)
data_sharding_v2 = NamedSharding(mesh, jax.P('data'))

# Place data on devices with sharding
x = jax.device_put(jnp.ones((1024, 256)), data_sharding)

# JIT with output sharding
@jax.jit
def forward(params, x):
    return x @ params

# Shard params: replicate across data axis
params = jax.device_put(jnp.ones((256, 128)), replicated)
output = forward(params, x)
print(output.sharding)  # shows how output is distributed
# Mixed precision is straightforward in JAX
@jax.jit
def train_step_bf16(params, x, y):
    def loss_fn(params):
        # Cast to bfloat16 for the forward pass
        x_bf16 = x.astype(jnp.bfloat16)
        logits = forward_bf16(params, x_bf16)
        return jnp.mean((logits.astype(jnp.float32) - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Gradients are in float32 — optimizer updates in full precision
    return loss, grads

External links

Exercise

subtle bug (indexing off-by-one) 의 jit'd 함수. jax.debug.print 로 trace 안 깨고 intermediate shape/value 확인. jax.debug.breakpoint() 설정 후 step. jit 아래의 디버깅이 자기 만의 skill.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.