Changing a tensor's shape makes up roughly 5-10% of all ML code, so it pays to get comfortable with it. JAX's reshape works the same way as NumPy's.
import jax.numpy as jnp
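a = jnp.arange(24) # shape (24,)
b = a.reshape(2, 3, 4) # shape (2, 3, 4)
c = a.reshape(-1, 4) # -1 means "infer this dimension": shape (6, 4)
d = a.reshape(4, 6).T # transpose: shape (6, 4)
e = jnp.expand_dims(a, axis=0) # add a leading axis: shape (1, 24)
f = jnp.squeeze(e) # drop size-1 axes: shape (24,)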
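In the listing above, `c` and `d` end up with the same shape, which makes them easy to confuse. The short sketch below (not part of the original lesson) compares the two; the variable names simply mirror the ones used above.

```python
import jax.numpy as jnp

a = jnp.arange(24)

c = a.reshape(6, 4)    # rows are consecutive runs of 4 elements
d = a.reshape(4, 6).T  # same shape, but the axes were swapped after reshaping

print(c.shape, d.shape)             # (6, 4) (6, 4)
print(c[0])                         # [0 1 2 3]
print(d[0])                         # [ 0  6 12 18]
print(bool(jnp.array_equal(c, d)))  # False
```

A matching shape says nothing about element order, which is why a shape check alone will not catch this kind of mistake.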
import jax.numpy as jnp
# 2D transpose
a = jnp.array([[1, 2, 3], [4, 5, 6]])
print(a.T.shape) # (3, 2)
# Higher-dimensional: permute axes
# Common in ML: converting between channels-first and channels-last
img = jnp.zeros((32, 3, 224, 224)) # (batch, channels, height, width)
# NCHW -> NHWC
img_nhwc = jnp.transpose(img, (0, 2, 3, 1))
print(img_nhwc.shape) # (32, 224, 224, 3)
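To go back from NHWC to NCHW you need the inverse permutation, and `reshape` is not a substitute for it. Here is a minimal sketch using a small stand-in array (the sizes 2, 3, 4, 5 are arbitrary, chosen only so the values are easy to track):

```python
import jax.numpy as jnp

# Small stand-in for an image batch: (batch, channels, height, width)
img = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

nhwc = jnp.transpose(img, (0, 2, 3, 1))   # NCHW -> NHWC
back = jnp.transpose(nhwc, (0, 3, 1, 2))  # NHWC -> NCHW (inverse permutation)

print(nhwc.shape)                        # (2, 4, 5, 3)
print(bool(jnp.array_equal(back, img)))  # True

# reshape reaches the same target shape but does not move the channel axis.
wrong = img.reshape(2, 4, 5, 3)
print(wrong.shape)                         # (2, 4, 5, 3)
print(bool(jnp.array_equal(wrong, nhwc)))  # False
```

The `wrong` array has a valid shape and raises no error, so a model fed with it would silently train on scrambled pixels.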
import jax.numpy as jnp
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
# Concatenate: join along existing axis
c = jnp.concatenate([a, b])
print(c) # [1 2 3 4 5 6]
# Stack: join along a NEW axis
s = jnp.stack([a, b])
print(s)
# [[1 2 3]
#  [4 5 6]]
print(s.shape) # (2, 3)
# vstack and hstack
v = jnp.vstack([a, b]) # Same as stack for 1D → 2D
h = jnp.hstack([a, b]) # Same as concatenate for 1D
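With 2D inputs, the `axis` argument decides where the join happens. The sketch below uses two illustrative (2, 3) arrays, `x` and `y`, which are not from the original listing:

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3))
y = jnp.ones((2, 3))

cat_rows = jnp.concatenate([x, y], axis=0)  # grow the existing row axis
cat_cols = jnp.concatenate([x, y], axis=1)  # grow the existing column axis
stack_first = jnp.stack([x, y], axis=0)     # new leading axis
stack_last = jnp.stack([x, y], axis=-1)     # new trailing axis

print(cat_rows.shape)     # (4, 3)
print(cat_cols.shape)     # (2, 6)
print(stack_first.shape)  # (2, 2, 3)
print(stack_last.shape)   # (2, 3, 2)
```

Concatenate keeps the number of dimensions and stack always adds one.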
Take a (32, 28, 28, 3) image batch. Flatten it to (32, 28*28*3). Unflatten it back. Transpose it to channels-first. Print the new shape and strides at every step. Practicing layout reshapes is the cheapest way to prevent silent bugs in neural network code.
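One possible way to work through the exercise is sketched below. It rests on an assumption: as far as I know, JAX arrays do not expose a `.strides` attribute, so the strides are read from a NumPy mirror of the same steps rather than from the JAX arrays. The `float32` dtype is also my choice, and it determines the byte counts.

```python
import jax.numpy as jnp
import numpy as np

# (batch, height, width, channels)
batch = jnp.zeros((32, 28, 28, 3), dtype=jnp.float32)

flat = batch.reshape(32, -1)                  # flatten each image
restored = flat.reshape(32, 28, 28, 3)        # unflatten
nchw = jnp.transpose(restored, (0, 3, 1, 2))  # channels-first

for name, arr in [("batch", batch), ("flat", flat),
                  ("restored", restored), ("nchw", nchw)]:
    print(name, arr.shape)

# NumPy mirror of the same steps, used only to inspect strides (in bytes).
np_batch = np.zeros((32, 28, 28, 3), dtype=np.float32)
np_flat = np_batch.reshape(32, -1)
np_nchw = np.transpose(np_batch, (0, 3, 1, 2))

print(np_batch.strides)  # (9408, 336, 12, 4)
print(np_flat.strides)   # (9408, 4)
print(np_nchw.strides)   # (9408, 4, 336, 12)
```

In the NumPy mirror, the transpose only reorders the strides and leaves the underlying buffer untouched, which is why a transposed array and a reshaped array with the same shape can hold different data.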