C.W.K.
Stream
Lesson 03 of 06 · published

MLX vs PyTorch MPS — 통역사 vs 원어민

~14 min · pytorch-mps, comparison, performance

Level 0Curious
0 XP0/51 lessons0/15 achievements
0/100 XP to next level100 XP to go0% complete

통역사 vs 원어민, 구체적으로

PyTorch MPS 가 NVIDIA 하드웨어 위해 박힌 PyTorch 코드가 Apple Silicon 의 Metal backend 에서 돌게 하는 CUDA-API 번역 layer. 모든 PyTorch op 이 그 Metal Performance Shader 등가물에 dispatch (또는 MPS 구현 없으면 CPU 로 fallback). 번역 동작; GPU 가 자기 메모리 가진 별도 device 라는 세계관 물려받을 뿐.

MLX 는 Apple Silicon 위해 직접 박힌 framework. Unified-memory 모델, lazy graph, function transform — 이 중 어느 것도 기존 API 위에 retrofit 안 됨; framework 의 design center.

실제로 이게 어떻게 보이나

  • API 표면 — PyTorch MPS 는 네가 아는 같은 torch.tensor, tensor.to('mps'), autograd-by-tape 패턴 사용. MLX 는 mx.array, .to() 없음, autograd-by-function-transform. 다른 비용에 다른 모양.
  • 성능 — 같은 모델에 PyTorch MPS 와 MLX 사이 갭이 small to moderate (가끔 한쪽 이김, 가끔 반대). 릴리스 사이 리더 바뀜. specific 워크로드 측정하지 않으면 성능으로 안 골라.
  • Coverage — PyTorch MPS 가 Metal 의 PyTorch op 100% 지원 안 함. 갭이 매 릴리스 닫히지만 여전히 조용히 CPU 로 fallback 하는 op hit 가능, 성능 죽임. MLX 는 더 작고 완전 지원 표면 가짐 — PyTorch 가 가진 같은 모델 안 받지만 받는 거는 동작.
  • 메모리 모델 — PyTorch MPS 가 Apple Silicon 이 안 필요해도 .to(device) 의식이 API 에 baked. MLX 의 API 가 하드웨어와 일치.

PyTorch MPS 가 맞는 호출일 때

  • 기존 PyTorch codebase 가 있고 재작성 없이 Mac 에서 로컬 개발 원함. Drop-in CUDA → MPS swap 이 PyTorch MPS 가 존재하는 이유.
  • PyTorch-specific 라이브러리 의존 (HuggingFace transformers, specific torch-only 모델). 아직 MLX 로 다 port 안 됨.
  • Cloud GPU 학습과 bug-for-bug 호환성 원함 — Mac 에서 prototype 하고 NVIDIA 에 배포할 수 있게, 동작 surprise 없이.

MLX 가 맞는 호출일 때

  • Apple Silicon 에 새로 시작하고 honor 할 PyTorch legacy 없음.
  • Function-transform 스타일 원함 (JAX-flavor — mx.grad, mx.vmap, 쉬운 compile).
  • GPU 빌리지 않고 LLM 로컬 fine-tune 원함 (mlx-lm 의 LoRA 워크플로가 PyTorch MPS 의 등가물보다 더 polish).
  • Mac-전용 배포에 출하하고 native-shape API 원함.

솔직한 중간 지점

호스티드 GPU 클러스터와 Mac 사이 bounce 하는 연구엔, PyTorch MPS 가 코드 통합 유지. Mac-전용 작업, 특히 Mac-전용 LLM 워크플로엔, MLX 가 다른 사람의 하드웨어 위해 디자인 안 된 framework.

Code

PyTorch MPS 와 MLX 의 같은 matmul·python
# PyTorch MPS
import torch
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(1024, 1024, device=device)
y = x @ x.T
print("PyTorch MPS:", tuple(y.shape), y.dtype, y.device)

# MLX
import mlx.core as mx
x_mlx = mx.random.normal((1024, 1024))
y_mlx = x_mlx @ x_mlx.T
mx.eval(y_mlx)
print("MLX        :", tuple(y_mlx.shape), y_mlx.dtype, mx.default_device())

# Verified outputs (2026-05-03):
#   PyTorch MPS: (1024, 1024) torch.float32 mps:0
#   MLX        : (1024, 1024) mlx.core.float32 Device(gpu, 0)

External links

Exercise

Env 에 torch 깔려 있으면, PyTorch MPS 와 MLX 에 같은 matmul 돌려 (이 레슨의 코드 블록이 둘 다 함). time.perf_counter() 로 timing. 갭 알아채. 그 다음 4096×4096 matmul 로 바꾸고 다시 timing. 갭이 자라, 줄어, 비슷하게 머무? 알아챈 거 두 문장.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.