C.W.K.
Stream
Lesson 07 of 07 · published

mx.compile — 측정 가능한 kernel fusion

~14 min · mx.compile, performance, fusion

Level 0Curious
0 XP0/51 lessons0/15 achievements
0/100 XP to next level100 XP to go0% complete

mx.compile 가 진짜 뭐 하나

mx.compile 가 함수를 감싸고 새 함수 돌려줘. 첫 호출에서 op 들을 graph 로 trace 하고 단일 (또는 더 작은 셋의) Metal kernel 로 컴파일. 같은 입력 shape 와 dtype 으로 모든 후속 호출에서, tracing 건너뛰고 컴파일된 kernel 을 직접 dispatch. 함수가 hot 일 때 결과 — 더 적은 kernel launch 와 메모리에 쓸 필요 없는 intermediate 의 더 빡빡한 fusion.

이걸 쓰는 데 Metal 이해할 필요 없어. 언제 쓸지 이해할 필요는 있어.

이 머신에서의 벤치마크

아래 코드 블록이 synthetic hot function 시간 재 — (512, 512) array 위 20 개 직렬 tanh op — 컴파일과 컴파일 없이 둘 다. 내 office Mac (M3 Ultra Studio, mlx 0.31.2, 2026-05-03) 에서:

  • Plain — ~0.97 ms / 호출
  • Compiled — ~0.31 ms / 호출
  • Speedup — ~3.18x

네 숫자는 머신마다, GPU 에 경쟁 중인 다른 거에 따라, 함수의 specific shape 에 따라 달라. 결과의 모양 — 의미 있지만 자릿수는 아닌 — 이 네가 기대해야 할 것.

언제 쓰나 (그리고 언제 안 쓰나)

— 입력 shape 가 안정인 hot, 반복 호출 함수에. Training-step 함수, inference-token-step 함수, 같은 shape signature 로 수백만 번 호출되는 무엇이든.

쓰지 마 — 몇 번만 호출하는 함수에. Tracing 과 컴파일 비용은 첫 호출에서 내고 많은 호출에 걸쳐 amortize. 일회성 계산엔 overhead 가 절약 초과.

조심 — 입력 shape 가 다양한 함수에. 모든 새 shape signature 가 recompile 트리거 — shape 가 계속 바뀌면 runtime 의 dominant. 한 번 컴파일되고 천 번 dispatch 되는 함수는 좋아. 천 번 recompile 되는 함수는 좋음의 반대.

안 하는 것

mx.compile 가 함수의 결과 안 바꿔 — 출력이 컴파일 안 한 버전과 bit-identical (또는 floating-point-equivalent). 그리고 모든 거 마법처럼 3x 빨라지지도 않아 — memory-bound op 이나 이미 단일 kernel 호출인 op 에는 speedup 작거나 0. 위 벤치마크의 3x 는 20 개 tanh op 을 한 kernel 로 fuse 한 데서 와. 이미 하나의 matmul 인 함수에선 거의 안 보일 거.

Code

Plain vs compiled — 직접 측정·python
import mlx.core as mx
import time


def heavy(x):
    y = x
    for _ in range(20):
        y = mx.tanh(y * 1.001 + 0.001)
    return y.sum()


heavy_compiled = mx.compile(heavy)

x = mx.random.normal((512, 512))

# Warm up both versions to amortize first-call costs (compile traces here).
mx.eval(heavy(x))
mx.eval(heavy_compiled(x))


def bench(fn, n=20):
    mx.eval(fn(x))   # warm
    t0 = time.perf_counter()
    for _ in range(n):
        r = fn(x)
        mx.eval(r)
    return (time.perf_counter() - t0) / n * 1000   # ms


t_plain = bench(heavy)
t_comp  = bench(heavy_compiled)

print(f'plain    : {t_plain:.3f} ms / call')
print(f'compiled : {t_comp:.3f} ms / call')
print(f'speedup  : {t_plain/t_comp:.2f}x')

# Verified on M3 Ultra Studio, mlx 0.31.2 (2026-05-03):
#   plain    : 0.969 ms / call
#   compiled : 0.305 ms / call
#   speedup  : 3.18x

External links

Exercise

네 머신에서 벤치마크 돌려. 그 다음 한 번에 한 변수 바꿔 다시 측정 — (a) array shape 를 (512, 512) 에서 (4096, 4096) 으로 바꿔 — speedup 유지돼, 자라, 줄어? (b) heavy 안 loop count 를 20 에서 1 로 바꿔 (즉 단일 op) — speedup 에 뭐 일어나? 가져갈 직관 — compile 은 fuse 할 작은 op 많고 shape 안정일 때 가장 도움.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.