C.W.K.
Stream
Lesson 05 of 06 · published

JAX array 가 NumPy 와 다른 지점

~10 min · numpy, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

지금까지 본 차이 정리 — 4 가지 결정적 다름:

  1. Immutability: a = a.at[0].set(99) 만 됨.
  2. Default dtype 이 float32: x64 enable 안 하면.
  3. Random 은 별도 모듈, key 기반: jax.random.PRNGKey + jax.random.normal(key, ...).
  4. Device 자동 배치: 가속기로 자동.

⚠️ 자주 나는 버그

NumPy 의 random seed 기반 코드를 JAX 로 옮기면 — random 이 매번 같은 값을 뱉어서 학습이 안 됨. jax.random.PRNGKey + split 으로 처음부터 다시 짜야 함.

기능               | NumPy             | JAX
-------------------|-------------------|------------------
mutate             | a[0] = 1          | a = a.at[0].set(1)
default float      | float64           | float32
random             | np.random (전역)  | jax.random + key
device             | RAM 만            | CPU/GPU/TPU 자동

Code

import numpy as np
import jax.numpy as jnp

# NumPy: in-place mutation works fine
np_arr = np.array([1, 2, 3])
np_arr[0] = 99
print(np_arr)  # [99, 2, 3]

# JAX: in-place mutation raises an error
jnp_arr = jnp.array([1, 2, 3])
# jnp_arr[0] = 99  # ERROR: JAX arrays are immutable

# Instead, use .at[].set() to create a NEW array
new_arr = jnp_arr.at[0].set(99)
print(new_arr)   # [99, 2, 3]
print(jnp_arr)   # [1, 2, 3] — original unchanged!
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40, 50])

# Set a value
x_new = x.at[2].set(99)           # [10, 20, 99, 40, 50]

# Add to a value
x_add = x.at[2].add(5)            # [10, 20, 35, 40, 50]

# Multiply
x_mul = x.at[2].mul(2)            # [10, 20, 60, 40, 50]

# Slice updates
x_slice = x.at[1:3].set(0)        # [10, 0, 0, 40, 50]

# Conditional update with jnp.where
mask = x > 25
x_where = jnp.where(mask, x * 2, x)  # [10, 20, 60, 80, 100]
import jax
import jax.numpy as jnp

# Check available devices
print(jax.devices())           # e.g., [CudaDevice(id=0)] or [TpuDevice(id=0)]
print(jax.default_backend())   # 'gpu', 'tpu', or 'cpu'

# Arrays are created on the default device
x = jnp.array([1.0, 2.0, 3.0])
print(x.devices())  # Shows which device(s) the array is on

# Explicitly place on a device
cpu_device = jax.devices('cpu')[0]
x_cpu = jax.device_put(x, cpu_device)
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# This works but BYPASSES JIT — the array is silently moved to CPU
result = np.sum(x)  # Uses NumPy, not JAX — no JIT, no GPU

# Always use jnp for JAX arrays
result = jnp.sum(x)  # Correct — uses JAX, can be JIT-compiled

External links

Exercise

in-place mutation (a[mask] = 0) 하는 5 줄 NumPy snippet. JAX 로 3 가지 방식: (1) jnp.where, (2) .at[mask].set(0), (3) functional helper. jit 안에서 깨끗히 compile 되는 거 확인. 셋 다 저장 — 패턴이 어디서나 반복.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.