C.W.K.
Stream
Lesson 03 of 05 · published

Functional Update: Modify 에서 Return 으로

~8 min · purity, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

NumPy 의 in-place mutation 을 JAX 의 functional update 로 옮기는 패턴 — 매일 쓰게 됨.

핵심 도구가 .at[] indexer:

import jax.numpy as jnp

a = jnp.zeros((4, 4))

# .at[index].set(value) — 새 array 반환, 원본 안 변함
b = a.at[0, 0].set(1.0)
print(a)  # 여전히 0
print(b)  # (0,0)에 1

# 다른 update operation 들
a.at[0].add(5.0)        # +=
a.at[0].multiply(2.0)   # *=
a.at[0].divide(2.0)     # /=
a.at[0].power(2.0)      # **=
a.at[0].min(0.0)        # element-wise min
a.at[0].max(1.0)        # element-wise max
a.at[0].apply(jnp.sqrt) # 임의 함수 적용

모든 indexing 형태 지원:

a = jnp.zeros((5, 5))

# slice
b = a.at[1:3, :].set(jnp.ones((2, 5)))

# fancy indexing
b = a.at[[0, 2, 4]].set(1.0)

# boolean mask
mask = jnp.eye(5).astype(bool)
b = a.at[mask].set(99.0)

# 체이닝
b = a.at[0, 0].set(1).at[1, 1].set(2).at[2, 2].set(3)

scatter / gather 같은 좀 더 복잡한 update:

# 같은 index 에 여러 번 더하기 — 누적
indices = jnp.array([0, 0, 1, 2, 0])
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = jnp.zeros(5).at[indices].add(values)
print(result)  # [8.0, 3.0, 4.0, 0.0, 0.0] — index 0 에 1+2+5=8

# scatter (segment_sum)
result2 = jax.ops.segment_sum(values, indices, num_segments=5)

NumPy 패턴 → JAX 패턴 매핑:

NumPy                       | JAX
----------------------------|---------------------------
a[i] = v                    | a = a.at[i].set(v)
a[i] += v                   | a = a.at[i].add(v)
a[i:j] = b                  | a = a.at[i:j].set(b)
a[mask] = 0                 | a = a.at[mask].set(0)
                            | (또는 jnp.where(mask, 0, a))
np.add.at(a, idx, vals)     | a.at[idx].add(vals)

💡 .at vs jnp.where

boolean mask 로 update 할 때 — jnp.where(mask, new, old) 가 보통 더 깔끔하고 빠름. .at 은 정확한 index 를 알 때 (특히 scatter) 강력함. 둘 다 jit + grad + vmap 친화적이라 — 둘 중 어느 걸 써도 큰 문제 없음. 가독성 우선.

처음 .at 패턴은 "왜 이렇게 길게 써야 하지?" 싶은데 — 익으면 — NumPy 의 a[i] = v 가 실제로 얼마나 위험한 추상화였는지 보이게 돼. side effect 를 한 줄짜리 syntax 로 숨겨 놓은 거였거든. JAX 는 그 비용을 가시화한 거.

Code

import jax.numpy as jnp

# IMPERATIVE (NumPy) style — won't work in JAX
# array[i] = value
# dict[key] = value
# object.attribute = value

# FUNCTIONAL (JAX) style — creates new values
# new_array = array.at[i].set(value)
# new_dict = {**dict, key: value}
# new_params = params._replace(attribute=value)  # NamedTuple
import jax.numpy as jnp

x = jnp.zeros((5, 5))

# Set values at specific positions
x = x.at[0, 0].set(1.0)
x = x.at[2, 3].set(5.0)
x = x.at[4, 4].set(1.0)

# Increment values
counts = jnp.zeros(10, dtype=jnp.int32)
data = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3])
# Build a histogram — add 1 at each index
histogram = counts.at[data].add(1)
print(histogram)  # [0 2 1 2 1 2 0 0 0 1]
# Note: duplicate indices accumulate (both 1s add, both 3s add, etc.)

# Scatter operations
indices = jnp.array([0, 2, 4])
values = jnp.array([10.0, 20.0, 30.0])
x = jnp.zeros(6)
x = x.at[indices].set(values)
print(x)  # [10.  0. 20.  0. 30.  0.]
import jax.numpy as jnp

x = jnp.array([-3.0, -1.0, 0.0, 2.0, 5.0])

# ReLU: replace negatives with 0
relu = jnp.where(x > 0, x, 0.0)
print(relu)  # [0. 0. 0. 2. 5.]

# Clamp between -1 and 1
clamped = jnp.where(x > 1, 1.0, jnp.where(x < -1, -1.0, x))
print(clamped)  # [-1. -1. 0. 1. 1.]

# Conditional based on another array
mask = jnp.array([True, False, True, False, True])
result = jnp.where(mask, x, 0.0)
print(result)  # [-3. 0. 0. 0. 5.]
import jax
import jax.numpy as jnp

# IMPERATIVE: modify accumulator in a loop
def imperative_cumsum(arr):
    result = []
    total = 0
    for x in arr:
        total += x
        result.append(total)
    return result

# FUNCTIONAL: use jax.lax.scan (functional loop)
def functional_cumsum(arr):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (next_carry, output)

    _, cumulative = jax.lax.scan(step, 0.0, arr)
    return cumulative

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(functional_cumsum(x))  # [1. 3. 6. 10. 15.]

External links

Exercise

4×4 grid. 10 개의 random update 를 .at[i,j].set(v) 로. 같은 update 를 NumPy 식 mutating loop 으로 (copy 위에). output equality 와 timing 비교. 가독성 차이 적기.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.