C.W.K.
Stream
Lesson 01 of 08 · published

Dataset 객체

~18 min · dataset, pytorch, data

Level 0Curious
0 XP0/73 lessons0/11 achievements
0/120 XP to next level120 XP to go0% complete

Dataset abstraction

PyTorch torch.utils.data.Dataset 은 'index 해서 한 example 가져올 수 있는 것' 의 canonical contract. Subclass 하고 __len____getitem__ 구현하면, 나머지 data pipeline (loader, sampler, augmentation) 이 위에서 그냥 작동.

Dataset 은 example 당 transformation 이 사는 곳: disk 에서 row 읽기, image decode, string tokenize, tensor normalize. Example index 에 의존하는 모든 게 여기. Batch 에 의존하는 (collation, max length 까지 padding) 건 DataLoader 의 collate function.

팁: Dataset 은 default 로 lazy — __getitem__ 은 누가 example 요청할 때만 돌아. Terabyte-scale data 를 RAM 에 다 안 올리고 train 할 수 있는 이유야.

Map-style vs iterable-style

Map-style (가장 흔함): Dataset subclass, __getitem__(idx) 구현. Random access, shuffling 지원, length 알려짐. Finite, indexable dataset 에 best.

Iterable-style: IterableDataset subclass, __iter__ 구현. Sequential streaming, random access 없음, length unknown 가능. Log stream, distributed sharding, 거대 web-scale corpus 에 best.

눈 가리고도 쓸 수 있어야 할 Dataset 3 개

  1. Tensor wrapper — in-memory data 에 TensorDataset(X, y). Prototyping 에 OK.
  2. Image folder — class-per-subfolder image dataset 에 torchvision.datasets.ImageFolder('path/').
  3. Custom CSV / JSONL__init__ 에 manifest 읽고 __getitem__ 에 lazy-load.
원칙: 대부분 data bug 가 __getitem__ 안에 숨어. Training 시작 전에 example 하나 print. Shape, dtype, label 이 본인 model 기대와 맞는지 확인.

Code

A custom Dataset for an image manifest·python
import torch
from torch.utils.data import Dataset
from PIL import Image
import csv, pathlib

class ManifestImageDataset(Dataset):
    def __init__(self, manifest_path, root, transform=None):
        self.root = pathlib.Path(root)
        self.transform = transform
        with open(manifest_path) as f:
            self.rows = list(csv.DictReader(f))  # [{'filename': ..., 'label': ...}]

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        img = Image.open(self.root / row["filename"]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = int(row["label"])
        return img, label

External links

Exercise

본인이 가진 tabular CSV 에 custom Dataset 작성. Index 해서 첫 3 example print, shape 와 dtype 확인. Doc 안 보고 이거 가능한 첫 순간이 PyTorch data layer 를 'own' 하는 첫 순간.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.