C.W.K.
Stream
Lesson 03 of 06 · published

JAX vs PyTorch vs TensorFlow

~12 min · origins, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

"근데 PyTorch 가 모든 걸 다 하는데 왜 JAX 야?" — 합리적 질문이야. 답은: 다른 패러다임이고, 각자 다른 것에 강해.

PyTorch: dynamic, eager, OOP. nn.Module 상속, forward 메서드, .backward() 호출하면 magic 처럼 gradient 가 parameter 에 붙음. 처음 배우는 사람한테 친절. 단점은 — graph 구조가 명시적이지 않아서 compiler 최적화 제한.

TensorFlow: 원래 graph-first 였는데 2.x 부터 eager mode default. 산업 인프라 (TFX, TF Serving, TF Lite, TFJS) 강력. 연구실에선 점점 덜 쓰는 추세.

JAX: functional, transformation-first. nn.Module 같은 거 없음 (Flax/Equinox 가 따로 제공). jit/grad/vmap/pmap 이 1 등 시민, 자유롭게 합성. compile-first 라서 XLA 최적화 깊음. 단점은 — 처음 진입 장벽이 있어 (functional 사고, pure function, PRNG key 등).

         | PyTorch          | TensorFlow       | JAX
---------|------------------|------------------|------------------
스타일   | OOP, eager       | OOP/graph hybrid | functional
미분     | tensor.backward()| tape 기반        | jax.grad (함수)
batch    | 손으로 처리      | 손으로 처리      | jax.vmap
multi-GPU| DDP / FSDP       | tf.distribute    | pmap / sharding
연구     | 1 위 (분야 다수) | 감소             | 상승 (DeepMind 등)

🧭 어느 걸 골라야 하나

처음이면 PyTorch 부터. 큰 회사 production 이면 TF 도 여전히. 연구 — 특히 functional / transform-heavy / TPU 활용 — 면 JAX. 답은 "둘 다 알면 좋음" 인데, 이 quest 는 JAX 의 왜 를 가르치는 게 목표야.

중요한 건 — JAX 가 PyTorch 를 죽이러 온 게 아니야. 같은 문제에 다른 답을 제시하는 거. 두 답이 공존하는 시대를 살고 있어.

Code

# PyTorch style (for comparison)
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

model = Model()
x = torch.tensor([[1.0, 2.0]])
y = model(x)          # Eager execution, tape is recording
loss = y.sum()
loss.backward()        # Replay tape to get gradients
print(model.linear.weight.grad)
# JAX style
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def loss_fn(params, x, y):
    pred = predict(params, x)
    return jnp.mean((pred - y) ** 2)

# Parameters are just arrays in a tuple — no special Variable type
params = (jnp.array([0.5, 0.3]), jnp.array(0.1))
x = jnp.array([[1.0, 2.0]])
y = jnp.array([1.5])

# Gradient is a function, not a method on a loss object
grads = jax.grad(loss_fn)(params, x, y)
print(grads)  # Tuple of gradient arrays matching params structure

External links

Exercise

같은 single-layer regression 을 PyTorch 와 JAX 로. 100 step 학습, 합성 데이터. 시간 측정. API 차이 적기 — gradient state 의 주인, device 관리, loop 의 가독성. 우월 가리지 말고 — JAX 가 더 단순한 점 3 가지, PyTorch 가 더 단순한 점 3 가지.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.