NumPy 의 in-place mutation 을 JAX 의 functional update 로 옮기는 패턴 — 매일 쓰게 됨.
핵심 도구가 .at[] indexer:
import jax.numpy as jnp
a = jnp.zeros((4, 4))
# .at[index].set(value) — 새 array 반환, 원본 안 변함
b = a.at[0, 0].set(1.0)
print(a) # 여전히 0
print(b) # (0,0)에 1
# 다른 update operation 들
a.at[0].add(5.0) # +=
a.at[0].multiply(2.0) # *=
a.at[0].divide(2.0) # /=
a.at[0].power(2.0) # **=
a.at[0].min(0.0) # element-wise min
a.at[0].max(1.0) # element-wise max
a.at[0].apply(jnp.sqrt) # 임의 함수 적용
모든 indexing 형태 지원:
a = jnp.zeros((5, 5))
# slice
b = a.at[1:3, :].set(jnp.ones((2, 5)))
# fancy indexing
b = a.at[[0, 2, 4]].set(1.0)
# boolean mask
mask = jnp.eye(5).astype(bool)
b = a.at[mask].set(99.0)
# 체이닝
b = a.at[0, 0].set(1).at[1, 1].set(2).at[2, 2].set(3)
scatter / gather 같은 좀 더 복잡한 update:
# 같은 index 에 여러 번 더하기 — 누적
indices = jnp.array([0, 0, 1, 2, 0])
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = jnp.zeros(5).at[indices].add(values)
print(result) # [8.0, 3.0, 4.0, 0.0, 0.0] — index 0 에 1+2+5=8
# scatter (segment_sum)
result2 = jax.ops.segment_sum(values, indices, num_segments=5)
NumPy 패턴 → JAX 패턴 매핑:
NumPy | JAX
----------------------------|---------------------------
a[i] = v | a = a.at[i].set(v)
a[i] += v | a = a.at[i].add(v)
a[i:j] = b | a = a.at[i:j].set(b)
a[mask] = 0 | a = a.at[mask].set(0)
| (또는 jnp.where(mask, 0, a))
np.add.at(a, idx, vals) | a.at[idx].add(vals)
💡 .at vs jnp.where
boolean mask 로 update 할 때 — jnp.where(mask, new, old) 가 보통 더 깔끔하고 빠름. .at 은 정확한 index 를 알 때 (특히 scatter) 강력함. 둘 다 jit + grad + vmap 친화적이라 — 둘 중 어느 걸 써도 큰 문제 없음. 가독성 우선.
처음 .at 패턴은 "왜 이렇게 길게 써야 하지?" 싶은데 — 익으면 — NumPy 의 a[i] = v 가 실제로 얼마나 위험한 추상화였는지 보이게 돼. side effect 를 한 줄짜리 syntax 로 숨겨 놓은 거였거든. JAX 는 그 비용을 가시화한 거.