NumPy 를 편하게 썼으면 다음 10 분은 복습이야. PyTorch indexing 룰:
Basic indexing — integer 와 slice. 가능하면 view 반환 (no copy).
Boolean indexing — 맞는 shape 의 bool tensor 로 index. 선택된 element 의 1-D tensor 반환 (항상 copy).
Advanced (fancy) indexing — integer tensor 로 index. 항상 copy.
Mixed — 위의 조합; PyTorch 는 NumPy 룰 따름.
가장 중요한 사실 하나: basic slicing 은 view 반환, copy 아님. slice 에 쓰면 원본에도 써. 기능이지 버그 아니야 — in-place update 를 싸게 만들어 — 근데 'tensor 하나 바꿨는데 셋이 따라 움직였다' 디버깅 스토리의 원천이기도 해.
가장 자주 쓰는 selector
x[:, 0] — 모든 row 의 첫 column.
x[..., -1] — 마지막 dim 의 마지막 element, x 의 dim 수와 무관. 아름답게 간결.
x[None, :] / x.unsqueeze(0) — batch dim 추가.
x[mask] — boolean filter.
torch.gather(x, dim, index) — index tensor 로 dim 따라 값 pick. attention 과 loss 코드에서 엄청 흔함.
Code
Basic indexing 과 slicing·python
import torch
t = torch.tensor([
[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
])
t[0] # tensor([1, 2, 3, 4]) — first row
t[0, 2] # tensor(3) — single element
t[-1] # last row
t[:, 1] # all rows, col 1 → tensor([2, 6, 10])
t[0:2, :] # first two rows
t[::2, :] # every other row
t[..., -1] # last col, regardless of rank → tensor([4, 8, 12])
import torch
a = torch.arange(12).reshape(3, 4)
row = a[1] # VIEW into a (no copy)
row[:] = 99 # writes into a too!
print(a)
# tensor([[ 0, 1, 2, 3],
# [99, 99, 99, 99],
# [ 8, 9, 10, 11]])
# To get an independent copy:
row_copy = a[1].clone()
row_copy[:] = -1
print(a[1]) # untouched: tensor([99, 99, 99, 99])
torch.gather — attention/loss 의 일꾼·python
import torch
# Pick the predicted-class probability for each sample
logits = torch.randn(4, 5).softmax(-1) # (batch=4, classes=5)
labels = torch.tensor([2, 0, 4, 1])
# We want logits[i, labels[i]] for each i
chosen = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
print(chosen.shape) # torch.Size([4])
# Equivalent (less efficient) loop:
# chosen = torch.tensor([logits[i, labels[i]] for i in range(4)])
torch.randn 으로 (4, 5) logit tensor 만들기. label tensor [2, 0, 4, 1] 주어졌을 때, 각 sample 의 label class 의 logit 을 둘 다로 추출: Python loop AND torch.gather. 같은 값 내는지 확인. timeit 으로 (1024, 1000) tensor 에서 비교 — gather 가 극적으로 빨라야 함.
Progress
Progress is local-only — sign in to sync across devices.