C.W.K.
Stream
Lesson 01 of 06 · published

jax.numpy: NumPy 호환 API

~8 min · numpy, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

jax.numpy 는 단순한 한 가지 약속이야 — "NumPy 의 API surface 를 거의 그대로 흉내낼게, 다만 실행은 XLA 위에서". 그래서 시작은 import 한 줄 바꾸기.

import numpy as np
import jax.numpy as jnp

a = np.array([1, 2, 3])
b = jnp.array([1, 2, 3])

print(np.sum(a))   # 6
print(jnp.sum(b))  # 6

# 거의 모든 게 그대로 — 함수 이름, 시그니처, 동작
np.zeros((3, 3));   jnp.zeros((3, 3))
np.linspace(0,1,5); jnp.linspace(0,1,5)
np.dot(a, a);       jnp.dot(b, b)

차이점은 실행 모델에서:

  • immutable: jnp.array 변경 불가. a[0] = 5 안 됨. 대신 a = a.at[0].set(5).
  • device 자동 배치: 만들어지는 즉시 가속기로.
  • dtype default 가 float32: NumPy 는 float64.
  • random 은 별도: jax.random + key. Track 8 에서.

⚠️ 첫 번째 함정

NumPy 코드를 그대로 복붙하면 — index 할당이 거의 무조건 깨져. .at[].set() 패턴을 손에 익혀야 해.

Code

import numpy as np
import jax.numpy as jnp

# Creating arrays — identical syntax
np_arr = np.array([1.0, 2.0, 3.0])
jnp_arr = jnp.array([1.0, 2.0, 3.0])

# Zeros, ones, ranges — identical
np_zeros = np.zeros((3, 4))
jnp_zeros = jnp.zeros((3, 4))

np_range = np.arange(0, 10, 2)
jnp_range = jnp.arange(0, 10, 2)

np_lin = np.linspace(0, 1, 50)
jnp_lin = jnp.linspace(0, 1, 50)

# Math — identical
print(np.sum(np_arr ** 2))     # 14.0
print(jnp.sum(jnp_arr ** 2))  # 14.0
import jax.numpy as jnp

# From Python data
a = jnp.array([[1, 2], [3, 4]])

# Filled arrays
zeros = jnp.zeros((3, 4))           # 3x4 of zeros
ones = jnp.ones((2, 3))             # 2x3 of ones
full = jnp.full((2, 2), fill_value=7.0)  # 2x2 of 7.0

# Ranges
r = jnp.arange(0, 10, 0.5)          # [0.0, 0.5, 1.0, ..., 9.5]
l = jnp.linspace(0, 1, 5)           # [0.0, 0.25, 0.5, 0.75, 1.0]

# Identity and diagonal
eye = jnp.eye(3)                     # 3x3 identity matrix
diag = jnp.diag(jnp.array([1, 2, 3]))  # 3x3 diagonal matrix

# Random arrays (JAX uses explicit PRNG keys, covered in Track 3)
import jax
key = jax.random.PRNGKey(0)
rand = jax.random.normal(key, shape=(3, 4))  # 3x4 standard normal
x = jnp.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])

print(x.shape)   # (2, 3)
print(x.dtype)   # float32
print(x.ndim)    # 2
print(x.size)    # 6

External links

Exercise

이전에 짠 30-line NumPy script (아무거나). import numpy as np 를 import jax.numpy as jnp 로, np. 를 jnp. 로 바꿔. 실행. 어떤 호출이 에러나고 어떤 게 조용히 동작이 바뀌는지 정확히 적기. 아직 고치지 마 — 진단이 lesson.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.