C.W.K.
Stream
Lesson 05 of 08 · published

Dataset, DataLoader, Worker 배관

~14 min · dataset, dataloader, num_workers, collate

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

GPU 잘 먹이는 두 abstraction

PyTorch 가 'sample 이 뭐냐' 와 '효율적으로 batch 어떻게' 분리:

  • Dataset__len____getitem__(idx) 정의. 최소 계약: sample 몇 개, i 번째 어떻게.
  • DataLoader — Dataset wrap, batching, shuffling, worker process 통한 parallel loading, fast GPU transfer 위 pin-memory 추가.

중요한 DataLoader knob

  • batch_size — 명백.
  • shuffle=True — 매 epoch 순서 random (training 전용; val/test 절대 X).
  • num_workers — parallel loading 위 subprocess 수. I/O-bound dataset 엔 N (CPU core 수) 설정. 0 은 main-process loading (가장 느리지만 디버깅 쉬움).
  • pin_memory=True — fast CPU→GPU transfer 위 page-locked memory 에 batch 할당. x.to(device, non_blocking=True) 와 결합 시 실재 speedup.
  • prefetch_factor — 각 worker 가 미리 prefetch 하는 batch 수. default 2; GPU 가 다음 batch 도착 전에 step 끝내면 올림.
  • persistent_workers=True — epoch 들 사이 worker process 살아있게. 매 epoch worker 재가동 비용 회피.
  • drop_last=True — 마지막 불완전 batch drop. 일관된 batch size 원할 때 유용 (BatchNorm stats, distributed training).

macOS num_workers 함정

macOS 에선 multiprocessing default 가 'spawn' (Linux 의 'fork' 아님). 각 worker 가 module 재 import — heavy dataset 시작 느릴 수 있고, 깔끔히 pickle 안 되는 객체는 에러. workaround: dev 엔 num_workers=0, production run 엔 올림.

Code

Custom Dataset — 최소 계약·python
import torch
from torch.utils.data import Dataset, DataLoader

class TensorDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.y[idx]

X = torch.randn(1000, 10)
y = torch.randint(0, 5, (1000,))
ds = TensorDataset(X, y)
print(len(ds), ds[0][0].shape, ds[0][1])
Production-shaped DataLoader·python
import torch
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,                # match CPU cores
    pin_memory=True,              # fast CPU→GPU
    prefetch_factor=4,            # batches each worker prefetches ahead
    persistent_workers=True,      # keep workers alive across epochs
    drop_last=True,               # consistent batch shape
)

device = "cuda"
for x, y in loader:
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    # ...training step...
random_split — 한 Dataset 에서 train / val·python
import torch
from torch.utils.data import random_split, DataLoader

dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 5, (1000,)))
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train

train_ds, val_ds = random_split(
    dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(42),   # reproducible split
)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
Custom collate — variable-length sequence·python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

class VariableSeqDataset(Dataset):
    def __init__(self):
        self.seqs = [torch.randint(0, 100, (torch.randint(5, 20, (1,)).item(),)) for _ in range(64)]
        self.labels = torch.randint(0, 2, (64,))

    def __len__(self): return len(self.seqs)
    def __getitem__(self, i): return self.seqs[i], self.labels[i]

def collate(batch):
    seqs, labels = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)
    lengths = torch.tensor([len(s) for s in seqs])
    return padded, lengths, torch.tensor(labels)

loader = DataLoader(VariableSeqDataset(), batch_size=8, collate_fn=collate)
batch = next(iter(loader))
print(batch[0].shape, batch[1], batch[2].shape)
# torch.Size([8, max_len])  tensor([12, 18, ...])  torch.Size([8])

External links

Exercise

directory tree (root/class_name/image.jpg) 에서 image 로드하는 Dataset 짓기. __len__ 와 __getitem__ 구현. num_workers=4 와 pin_memory=True 의 DataLoader 로 wrap. 머신에서 sample/sec throughput 시간 — 나중에 slowdown 발견에 유용한 baseline.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.