jax.numpy 는 단순한 한 가지 약속이야 — "NumPy 의 API surface 를 거의 그대로 흉내낼게, 다만 실행은 XLA 위에서". 그래서 시작은 import 한 줄 바꾸기.
import numpy as np
import jax.numpy as jnp
a = np.array([1, 2, 3])
b = jnp.array([1, 2, 3])
print(np.sum(a)) # 6
print(jnp.sum(b)) # 6
# 거의 모든 게 그대로 — 함수 이름, 시그니처, 동작
np.zeros((3, 3)); jnp.zeros((3, 3))
np.linspace(0,1,5); jnp.linspace(0,1,5)
np.dot(a, a); jnp.dot(b, b)
차이점은 실행 모델에서:
- immutable: jnp.array 변경 불가. a[0] = 5 안 됨. 대신 a = a.at[0].set(5).
- device 자동 배치: 만들어지는 즉시 가속기로.
- dtype default 가 float32: NumPy 는 float64.
- random 은 별도: jax.random + key. Track 8 에서.
⚠️ 첫 번째 함정
NumPy 코드를 그대로 복붙하면 — index 할당이 거의 무조건 깨져. .at[].set() 패턴을 손에 익혀야 해.