import jax.numpy as jnp
a = jnp.zeros((3, 3))
# JAX arrays are immutable: every `.at[...]` update returns a NEW array and
# leaves `a` itself unchanged.
b = a.at[1, 2].set(7.0)  # copy of `a` with element [1, 2] set to 7.0
c = a.at[:, 0].add(1.0)  # copy of `a` with 1.0 added to every element of column 0
# Masked update. `a.at[a > 0].set(0.0)` works eagerly, but under `jax.jit` a
# boolean mask selects a value-dependent number of elements and fails with
# NonConcreteBooleanIndexError. `jnp.where` keeps the output shape static, so it
# is jit-safe and produces the same array.
d = jnp.where(a > 0, 0.0, a)
💡 jnp.where 가 친구야
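boolean mask 로 값을 update 하고 싶을 때는 jnp.where(cond, a, b) 가 거의 항상 답이야. jit 안에서는 mask 로 고른 원소 개수가 값에 따라 달라지기 때문에 `x[mask]` / `x.at[mask]` 가 trace 되지 않지만, jnp.where 는 출력 shape 이 고정이라 안전해. 예를 들어 x = jnp.where(x < 0, 0, x) 가 바로 ReLU.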
한 마디로 정리하면 — JAX 의 indexing 은 "읽을 땐 NumPy, 쓸 땐 .at, 분기는 jnp.where". 이 패턴이 손에 익으면 NumPy 코드를 JAX 로 옮기는 게 거의 자동이야.
Code
import jax.numpy as jnp
x = jnp.arange(1, 10).reshape(3, 3)  # [[1 2 3] [4 5 6] [7 8 9]]

# --- Basic indexing: an integer picks one position, ':' keeps a whole axis ---
print(x[0, 1])  # 2
print(x[1])     # [4 5 6]  (second row)
print(x[:, 0])  # [1 4 7]  (first column)

# --- Slicing: start:stop:step per axis; stop is exclusive ---
print(x[0:2, 1:3])  # [[2 3] [5 6]]
print(x[::2])       # [[1 2 3] [7 8 9]] — every other row

# --- Boolean indexing ---
# Fine eagerly, but the result length depends on the mask's values, so it
# cannot be traced inside jax.jit (NonConcreteBooleanIndexError).
mask = jnp.greater(x, 5)
print(x[mask])  # [6 7 8 9]

# --- Fancy (integer-array) indexing: gathers the listed rows ---
indices = jnp.asarray([0, 2])
print(x[indices])  # [[1 2 3] [7 8 9]]
import jax.numpy as jnp
a = jnp.arange(1.0, 4.0)  # [1. 2. 3.]
b = a + 3.0               # [4. 5. 6.]

# Operators act element-wise on equally shaped arrays
print(a + b)   # [5. 7. 9.]
print(a * b)   # [ 4. 10. 18.]
print(a / b)   # [0.25 0.4  0.5 ]
print(a ** 2)  # [1. 4. 9.]

# Universal math functions, also element-wise (values rounded here)
print(jnp.sin(a))   # [0.841 0.909 0.141]
print(jnp.exp(a))   # [ 2.718  7.389 20.086]
print(jnp.log(a))   # [0.    0.693 1.099]
print(jnp.sqrt(a))  # [1.    1.414 1.732]

# Broadcasting follows NumPy rules: a (4,) vector is stretched across each
# row of a (3, 4) matrix.
matrix = jnp.ones((3, 4))
row_vec = jnp.arange(1.0, 5.0)  # [1. 2. 3. 4.]
print((matrix + row_vec).shape)  # (3, 4)
x = jnp.arange(1.0, 7.0).reshape(2, 3)  # [[1. 2. 3.] [4. 5. 6.]]

print(x.sum())        # 21.0 — total of every element
print(x.sum(axis=0))  # [5. 7. 9.] — axis 0 collapsed: one sum per column
print(x.sum(axis=1))  # [ 6. 15.] — axis 1 collapsed: one sum per row
print(x.mean())       # 3.5
print(x.max(axis=1))  # [3. 6.] — largest value in each row
print(x.argmax(axis=1))  # [2 2] — column index of each row's maximum
A = jnp.arange(1.0, 5.0).reshape(2, 2)  # [[1. 2.] [3. 4.]]
B = A + 4.0                             # [[5. 6.] [7. 8.]]

# For 2-D inputs these three spellings are interchangeable; each prints
# [[19. 22.] [43. 50.]]. (For N-D inputs jnp.dot follows different rules.)
print(jnp.matmul(A, B))
print(jnp.dot(A, B))
print(A @ B)

# Transpose, attribute form and function form
print(A.T)               # [[1. 3.] [2. 4.]]
print(jnp.transpose(A))  # identical to A.T

# `@` batches over leading dimensions: (10, 3, 4) @ (10, 4, 5) -> (10, 3, 5)
batch_A, batch_B = jnp.ones((10, 3, 4)), jnp.ones((10, 4, 5))
result = batch_A @ batch_B
print(result.shape)  # (10, 3, 5)