C.W.K.
Stream
Lesson 03 of 06 · published

Indexing, Slicing, 수치 연산

~9 min · numpy, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

JAX 에서 array 를 읽는 방법은 NumPy 와 같아. 쓰는 방법이 다를 뿐.

import jax.numpy as jnp

a = jnp.zeros((3, 3))

# 새로운 배열 반환 (원본은 안 변함)
b = a.at[1, 2].set(7.0)
c = a.at[:, 0].add(1.0)
d = a.at[a > 0].set(0.0)

💡 jnp.where 가 친구야

boolean indexing 으로 update 하고 싶을 때 — jnp.where(cond, a, b) 가 거의 항상 답이야. x = jnp.where(x < 0, 0, x) 가 ReLU.

한 마디 — JAX 의 indexing 은 "읽을 땐 NumPy, 쓸 땐 .at, 분기는 jnp.where". 이 패턴이 손에 익으면 NumPy 코드를 JAX 로 옮기는 게 거의 자동.

Code

import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# Basic indexing
print(x[0, 1])      # 2
print(x[1])          # [4, 5, 6]
print(x[:, 0])       # [1, 4, 7]

# Slicing
print(x[0:2, 1:3])   # [[2, 3], [5, 6]]
print(x[::2])         # [[1, 2, 3], [7, 8, 9]] — every other row

# Boolean indexing (works, but has caveats under jit)
mask = x > 5
print(x[mask])        # [6, 7, 8, 9]

# Fancy indexing
indices = jnp.array([0, 2])
print(x[indices])     # [[1, 2, 3], [7, 8, 9]]
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# Arithmetic
print(a + b)       # [5. 7. 9.]
print(a * b)       # [4. 10. 18.]
print(a / b)       # [0.25 0.4  0.5]
print(a ** 2)      # [1. 4. 9.]

# Math functions
print(jnp.sin(a))        # [0.841 0.909 0.141]
print(jnp.exp(a))        # [2.718 7.389 20.086]
print(jnp.log(a))        # [0.    0.693 1.099]
print(jnp.sqrt(a))       # [1.    1.414 1.732]

# Broadcasting — just like NumPy
matrix = jnp.ones((3, 4))
row_vec = jnp.array([1.0, 2.0, 3.0, 4.0])
print((matrix + row_vec).shape)  # (3, 4) — row added to each row
x = jnp.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])

print(jnp.sum(x))           # 21.0 — sum of all elements
print(jnp.sum(x, axis=0))   # [5. 7. 9.] — sum along rows
print(jnp.sum(x, axis=1))   # [6. 15.] — sum along columns
print(jnp.mean(x))          # 3.5
print(jnp.max(x, axis=1))   # [3. 6.]
print(jnp.argmax(x, axis=1))# [2, 2] — indices of max values
A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B = jnp.array([[5.0, 6.0], [7.0, 8.0]])

# Matrix multiplication — three equivalent ways
print(jnp.matmul(A, B))
print(jnp.dot(A, B))
print(A @ B)
# All give: [[19. 22.], [43. 50.]]

# Transpose
print(A.T)            # [[1. 3.], [2. 4.]]
print(jnp.transpose(A))  # Same

# Batched matmul — works with higher-dimensional arrays
batch_A = jnp.ones((10, 3, 4))
batch_B = jnp.ones((10, 4, 5))
result = batch_A @ batch_B
print(result.shape)   # (10, 3, 5)

External links

Exercise

5×5 array. row + col 이 odd 인 곳 모두 0 으로 — jnp.where 로. 같은 일을 .at[].set() functional update 로. 두 스타일 가독성 비교. jit 안에서 어느 게 먼저 손에 잡힐지?

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.