Random is a separate, key-based module: jax.random.PRNGKey + jax.random.normal(key, ...).
Device placement is automatic: arrays go to the accelerator (GPU/TPU) when one is available.
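⚠️ A common bug
If you port NumPy code that relies on a global random seed to JAX as-is, the random calls return the same values every time (the same key gets passed to every call, and JAX keys carry no hidden state), so training doesn't work. Rewrite the random logic from the start around jax.random.PRNGKey + split.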
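A minimal sketch of the bug and its fix, assuming a recent JAX with `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.normal`. Reusing one key returns identical samples; splitting hands each call a fresh subkey. The loop and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Bug: the same key is passed every time, so every "random" draw is identical
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(jnp.array_equal(a, b))  # True

# Fix: split the key and consume each subkey exactly once
samples = []
for _ in range(3):
    key, subkey = jax.random.split(key)  # keep `key` around for the next split
    samples.append(jax.random.normal(subkey, (3,)))
print(jnp.array_equal(samples[0], samples[1]))  # False
```

Threading `key` through the loop like this is what replaces NumPy's hidden global state.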
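Feature            | NumPy                    | JAX
-------------------|--------------------------|--------------------------------------
mutation           | a[0] = 1                 | a = a.at[0].set(1)
default float      | float64                  | float32 (float64 needs jax_enable_x64)
random             | np.random (global state) | jax.random + explicit key
device             | CPU (host RAM) only      | CPU/GPU/TPU, placed automatically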
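The default-float row is easy to check directly. A small sketch, assuming the `jax_enable_x64` flag is left at its default (off); it also shows that float64 NumPy data is downcast when it enters JAX.

```python
import numpy as np
import jax.numpy as jnp

np_x = np.array([1.0, 2.0])
jnp_x = jnp.array([1.0, 2.0])
print(np_x.dtype)   # float64
print(jnp_x.dtype)  # float32 (with jax_enable_x64 off)

# Converting existing float64 NumPy data also yields float32 while x64 is off
z = jnp.asarray(np_x)
print(z.dtype)  # float32
```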
import numpy as np
import jax.numpy as jnp
# NumPy: in-place mutation works fine
np_arr = np.array([1, 2, 3])
np_arr[0] = 99
print(np_arr) # [99, 2, 3]
# JAX: in-place mutation raises an error
jnp_arr = jnp.array([1, 2, 3])
# jnp_arr[0] = 99 # ERROR: JAX arrays are immutable
# Instead, use .at[].set() to create a NEW array
new_arr = jnp_arr.at[0].set(99)
print(new_arr) # [99, 2, 3]
print(jnp_arr) # [1, 2, 3] — original unchanged!
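To see the immutability rule concretely, and to confirm that `.at[].set()` also works inside `jax.jit`, here is a small sketch. The helper name `set_first` is hypothetical; `TypeError` is what current JAX raises on item assignment.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

try:
    x[0] = 99
    mutated = True
except TypeError:
    mutated = False  # JAX arrays do not support item assignment
print(mutated)  # False

@jax.jit
def set_first(arr, value):
    # Functional update: returns a new array and leaves `arr` untouched
    return arr.at[0].set(value)

y = set_first(x, 99)
print(y)  # [99  2  3]
print(x)  # [1 2 3]
```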
import jax.numpy as jnp
x = jnp.array([10, 20, 30, 40, 50])
# Set a value
x_new = x.at[2].set(99) # [10, 20, 99, 40, 50]
# Add to a value
x_add = x.at[2].add(5) # [10, 20, 35, 40, 50]
# Multiply
x_mul = x.at[2].multiply(2) # [10, 20, 60, 40, 50]
# Slice updates
x_slice = x.at[1:3].set(0) # [10, 0, 0, 40, 50]
# Conditional update with jnp.where
mask = x > 25
x_where = jnp.where(mask, x * 2, x) # [10, 20, 60, 80, 100]
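One behavioral difference worth knowing with `.at[].add()`: when an index repeats, JAX applies every update, whereas NumPy's buffered `a[idx] += y` increments a repeated index only once. NumPy's unbuffered `np.add.at` matches the JAX behavior. A minimal sketch:

```python
import numpy as np
import jax.numpy as jnp

idx = [0, 0, 2]  # index 0 appears twice

# NumPy buffered in-place add: the repeated index is incremented once
a = np.zeros(4)
a[idx] += 1
print(a)  # [1. 0. 1. 0.]

# NumPy unbuffered add: every occurrence counts
b = np.zeros(4)
np.add.at(b, idx, 1)
print(b)  # [2. 0. 1. 0.]

# JAX .at[].add(): every occurrence counts, like np.add.at
c = jnp.zeros(4).at[jnp.array(idx)].add(1)
print(c)  # [2. 0. 1. 0.]
```

This matters when porting scatter-style code such as histogram or gradient accumulation.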
import jax
import jax.numpy as jnp
# Check available devices
print(jax.devices()) # e.g., [CudaDevice(id=0)] or [TpuDevice(id=0)]
print(jax.default_backend()) # 'gpu', 'tpu', or 'cpu'
# Arrays are created on the default device
x = jnp.array([1.0, 2.0, 3.0])
print(x.devices()) # Shows which device(s) the array is on
# Explicitly place on a device
cpu_device = jax.devices('cpu')[0]
x_cpu = jax.device_put(x, cpu_device)
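A short sketch of moving data between host NumPy and a JAX device, assuming a JAX version where `jax.Array.devices()` returns a set of devices. `jax.device_get` copies a JAX array back into a NumPy array.

```python
import numpy as np
import jax

cpu_device = jax.devices("cpu")[0]

host = np.arange(3.0)                      # NumPy array in host memory
on_cpu = jax.device_put(host, cpu_device)  # JAX array on the chosen device
print(on_cpu.devices() == {cpu_device})    # True

back = jax.device_get(on_cpu)              # copy back to a NumPy array
print(type(back).__name__)                 # ndarray
```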
import numpy as np
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
# This runs, but NumPy copies the data to host memory and computes on CPU;
# inside jax.jit it raises an error, since traced values can't become NumPy arrays
result = np.exp(x) # NumPy ufunc: returns a NumPy ndarray, no JIT, no GPU
# Always use jnp for JAX arrays
result = jnp.exp(x) # Correct: uses JAX, can be JIT-compiled
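Calling a NumPy ufunc on a traced value inside `jax.jit` fails, and this is easy to reproduce. A sketch, assuming a current JAX where converting a tracer to NumPy raises `jax.errors.TracerArrayConversionError`; the function names are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def with_numpy(x):
    return np.exp(x)   # NumPy tries to convert the tracer into an ndarray

@jax.jit
def with_jnp(x):
    return jnp.exp(x)  # stays inside JAX, so it can be traced and compiled

x = jnp.array([0.0, 1.0])

try:
    with_numpy(x)
    numpy_ok = True
except jax.errors.TracerArrayConversionError:
    numpy_ok = False
print(numpy_ok)  # False

print(with_jnp(x))
```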
Take a 5-line NumPy snippet that mutates an array in place (a[mask] = 0) and port it to JAX in three ways: (1) jnp.where, (2) .at[mask].set(0), (3) a functional helper. Confirm that each compiles cleanly inside jit. Save all three, because this pattern comes up everywhere.
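One possible sketch of the exercise, assuming a 1-D float array and the mask `a < 0`; `zero_where` is a hypothetical helper name, not a JAX API. Approaches (1) and (3) are compiled with `jax.jit`. Approach (2) is run eagerly here: boolean-mask indexing generally needs concrete mask values during tracing, so whether `.at[mask].set(0)` traces under `jax.jit` is exactly what the exercise asks you to check on your JAX version.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy original: in-place mutation with a boolean mask
a_np = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
mask_np = a_np < 0
a_np[mask_np] = 0
print(a_np)  # [0. 0. 0. 1. 2.]

a = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# (1) jnp.where: elementwise select, output shape never changes
@jax.jit
def clip_where(x):
    return jnp.where(x < 0, 0.0, x)

# (2) .at[mask].set(0): functional version of a[mask] = 0 (run eagerly here)
def clip_at(x):
    return x.at[x < 0].set(0.0)

# (3) hypothetical functional helper that wraps the jnp.where pattern
def zero_where(x, mask):
    return jnp.where(mask, jnp.zeros_like(x), x)

@jax.jit
def clip_helper(x):
    return zero_where(x, x < 0)

r1, r2, r3 = clip_where(a), clip_at(a), clip_helper(a)
print(r1)  # [0. 0. 0. 1. 2.]
```

Because `jnp.where` keeps the output shape fixed regardless of the mask, it is the variant that composes most predictably with `jit`.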