JAX 의 default dtype 은 float32. NumPy 는 float64. 이 차이를 모르면 — 같은 코드인데 답이 다르게 나오는 일 생겨.
import numpy as np
import jax.numpy as jnp
a_np = np.array([1.0, 2.0, 3.0])
a_jax = jnp.array([1.0, 2.0, 3.0])
print(a_np.dtype) # float64
print(a_jax.dtype) # float32 (!)
왜 float32 가 default 냐 — 가속기 (GPU/TPU) 에서 float32 가 보통 float64 의 2~30 배 빨라. ML workload 는 거의 다 float32 로 충분히 수렴해.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
a = jnp.array([1.0, 2.0])
print(a.dtype) # float64 (이제 됨)
⚠️ 정밀도 함정
NumPy 에서 float64 로 동작하던 코드를 JAX 로 옮기면 — 같은 답이 안 나올 수 있어. 학습이 안 되거나 NaN 나면 가장 먼저 dtype 체크.