C.W.K.
Stream
Lesson 04 of 06 · published

Reshaping 과 Broadcasting

~8 min · numpy, jax, tutorial

Level 0호기심
0 XP0/73 lessons0/17 achievements
0/100 XP to next level100 XP to go0% complete

Tensor 의 모양을 바꾸는 건 — 모든 ML 코드의 5~10% 를 차지해. 익숙해져야 함. JAX reshape 는 NumPy 와 동일.

import jax.numpy as jnp

a = jnp.arange(24)
b = a.reshape(2, 3, 4)
c = a.reshape(-1, 4)         # -1 = "알아서 계산"
d = a.reshape(4, 6).T        # transpose
e = jnp.expand_dims(a, axis=0)
f = jnp.squeeze(e)

image batch 의 표준 패턴:

imgs = jnp.zeros((32, 28, 28, 3))         # NHWC
imgs_chw = imgs.transpose(0, 3, 1, 2)     # NCHW
flat = imgs.reshape(32, -1)               # for dense layer

📐 shape 사고법

JAX 코드 짤 때 — 매 줄마다 머릿속으로 shape 을 추적해. print(x.shape) 디버깅이 친구야.

Code

import jax.numpy as jnp

x = jnp.arange(12)  # [0, 1, 2, ..., 11]

# Reshape: change shape without changing data
a = jnp.reshape(x, (3, 4))   # 3 rows, 4 columns
b = x.reshape(3, 4)           # Method syntax works too
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

# -1 infers one dimension
c = x.reshape(2, -1)  # (2, 6) — JAX figures out 6
d = x.reshape(-1, 3)  # (4, 3)
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])  # shape: (3,)

# expand_dims: add a dimension
row = jnp.expand_dims(v, axis=0)   # shape: (1, 3)
col = jnp.expand_dims(v, axis=1)   # shape: (3, 1)

# Equivalent using None/newaxis indexing
row2 = v[None, :]   # shape: (1, 3)
col2 = v[:, None]   # shape: (3, 1)

# squeeze: remove dimensions of size 1
x = jnp.zeros((1, 3, 1, 4))
squeezed = jnp.squeeze(x)          # shape: (3, 4)
partial = jnp.squeeze(x, axis=0)   # shape: (3, 1, 4)
import jax.numpy as jnp

# 2D transpose
a = jnp.array([[1, 2, 3], [4, 5, 6]])
print(a.T.shape)  # (3, 2)

# Higher-dimensional: permute axes
# Common in ML: converting between channels-first and channels-last
img = jnp.zeros((32, 3, 224, 224))  # (batch, channels, height, width)

# NCHW -> NHWC
img_nhwc = jnp.transpose(img, (0, 2, 3, 1))
print(img_nhwc.shape)  # (32, 224, 224, 3)
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Concatenate: join along existing axis
c = jnp.concatenate([a, b])
print(c)  # [1 2 3 4 5 6]

# Stack: join along a NEW axis
s = jnp.stack([a, b])
print(s)        # [[1 2 3], [4 5 6]]
print(s.shape)  # (2, 3)

# vstack and hstack
v = jnp.vstack([a, b])  # Same as stack for 1D → 2D
h = jnp.hstack([a, b])  # Same as concatenate for 1D

External links

Exercise

(32, 28, 28, 3) image batch. (32, 28*28*3) 으로 flatten. 다시 unflatten. channels-first 로 transpose. 매 단계 새 shape 와 stride 출력. layout reshape 연습이 NN 코드의 silent bug 예방의 가장 싼 방법.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.