C.W.K.
Stream
Lesson 04 of 07 · published

Custom Collate Functions and IterableDataset

~12 min · collate, iterable, stream, padding


When samples don't stack cleanly

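The default collate function stacks each field across the samples in a batch into a single tensor. That works when every sample has the same shape: image batches, fixed-feature regression. It breaks for variable-length sequences, for per-item metadata that can't be stacked, or whenever you want a custom batch structure.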
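A quick way to see the failure mode is to call the default collate directly. This sketch assumes a PyTorch version that exports default_collate from torch.utils.data (recent releases do). Same-shape tensors stack into a batch; tensors of different lengths make the internal torch.stack raise a RuntimeError.

```python
import torch
from torch.utils.data import default_collate

# Two samples with identical shape: stacked along a new leading batch dimension.
same_shape = [torch.zeros(3), torch.ones(3)]
stacked = default_collate(same_shape)
print(stacked.shape)  # torch.Size([2, 3])

# Two samples with different lengths: torch.stack cannot combine them.
ragged = [torch.zeros(3), torch.ones(5)]
try:
    default_collate(ragged)
    ragged_failed = False
except RuntimeError as err:
    ragged_failed = True
    print("default_collate failed:", err)
```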

Custom collate

A collate function takes a List[Sample] and returns a batch. The most common variant pads variable-length sequences to the length of the longest sequence in the batch.
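To make the padding step concrete, here is a minimal hand-written pad-to-longest collate that does not use pad_sequence. The name pad_collate_with_mask and the pad value 0 are illustrative assumptions; besides the padded batch it returns a boolean mask marking the real (non-padded) positions.

```python
import torch

def pad_collate_with_mask(batch, pad_value=0):
    """Hypothetical collate: batch is a list of (seq, label) pairs, seq a 1-D tensor."""
    seqs, labels = zip(*batch)
    max_len = max(s.size(0) for s in seqs)  # longest sequence in THIS batch
    padded = torch.full((len(seqs), max_len), pad_value, dtype=seqs[0].dtype)
    mask = torch.zeros((len(seqs), max_len), dtype=torch.bool)
    for i, s in enumerate(seqs):
        padded[i, : s.size(0)] = s  # copy real tokens; the rest stays pad_value
        mask[i, : s.size(0)] = True  # True where the position holds a real token
    return padded, mask, torch.stack(labels)  # 0-d label tensors -> shape (B,)

# Usage with a hand-built batch of three sequences of lengths 2, 4 and 1.
batch = [
    (torch.tensor([5, 6]), torch.tensor(1)),
    (torch.tensor([7, 8, 9, 10]), torch.tensor(0)),
    (torch.tensor([11]), torch.tensor(1)),
]
padded, mask, labels = pad_collate_with_mask(batch)
print(padded.tolist())   # [[5, 6, 0, 0], [7, 8, 9, 10], [11, 0, 0, 0]]
print(mask.sum(dim=1))   # tensor([2, 4, 1])
```

Because DataLoader calls collate_fn with a single argument, the default pad_value lets you pass collate_fn=pad_collate_with_mask directly.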

IterableDataset — streaming data

For datasets too large to fit on disk or to index randomly (huge text corpora, network-streamed data, live sensor feeds), implement IterableDataset instead of Dataset. You define only __iter__(self); the DataLoader pulls samples from your iterator and groups them into batches.
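A minimal sketch of the idea, assuming a source that can only be read front to back. An in-memory list of strings stands in for a file or network stream; the class name LineStream and the toy featurization are hypothetical. Only __iter__ is defined (no __len__, no __getitem__), and the DataLoader consumes samples in order.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class LineStream(IterableDataset):
    """Hypothetical stream: yields one (feature, label) pair per line."""
    def __init__(self, lines):
        self.lines = lines  # stand-in for an open file or network stream

    def __iter__(self):
        for line in self.lines:
            # Toy featurization: line length as a 1-element float tensor;
            # label is 1 if the line starts with "a", else 0.
            yield torch.tensor([float(len(line))]), int(line.startswith("a"))

lines = ["apple", "banana", "avocado", "cherry", "apricot"]
loader = DataLoader(LineStream(lines), batch_size=2)  # num_workers=0: single process
batches = list(loader)
for x, y in batches:
    print(x.shape, y)
```

With five items and batch_size=2 the last batch holds a single sample; pass drop_last=True to the DataLoader if you need uniform batch sizes.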

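The catch with a multi-worker IterableDataset: each worker process gets its own copy of the dataset, so unless you split the stream manually, every worker iterates the full stream and each sample appears once per worker. Inside __iter__, call torch.utils.data.get_worker_info() and give each worker its own slice of the stream based on its worker id.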
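Contiguous chunks are one way to split the stream. Another, sketched here under the assumption that the source is a plain Python iterable, is strided sharding with itertools.islice: worker k takes items k, k + num_workers, k + 2 * num_workers, and so on. The helper shard_for_worker and the class StridedStream are hypothetical names; the check at the end simulates four workers in one process instead of spawning real ones.

```python
import itertools
import torch
from torch.utils.data import IterableDataset, get_worker_info

def shard_for_worker(iterable, worker_id, num_workers):
    """Hypothetical helper: keep every num_workers-th item, starting at worker_id."""
    return itertools.islice(iterable, worker_id, None, num_workers)

class StridedStream(IterableDataset):
    def __init__(self, n_total):
        self.n_total = n_total

    def __iter__(self):
        stream = iter(range(self.n_total))  # stand-in for a real sample stream
        info = get_worker_info()
        if info is not None:  # running inside a DataLoader worker process
            stream = shard_for_worker(stream, info.id, info.num_workers)
        for i in stream:
            yield torch.tensor([float(i)]), i % 2

# Simulate 4 workers in-process: shards are disjoint and cover every index once.
shards = [list(shard_for_worker(range(10), k, 4)) for k in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
covered = sorted(i for shard in shards for i in shard)
```

One trade-off: islice still reads and discards the items that belong to other workers, which is cheap for a range but wasteful for an expensive source. Contiguous chunks, or handing each worker its own list of files, avoid that extra reading.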

Code

Pad-to-longest collate for variable-length sequences·python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

class VarSeqDataset(Dataset):
    def __init__(self):
        self.seqs = [torch.randint(1, 100, (torch.randint(5, 25, (1,)).item(),)) for _ in range(64)]
        self.labels = torch.randint(0, 2, (64,))
    def __len__(self): return len(self.seqs)
    def __getitem__(self, i): return self.seqs[i], self.labels[i]

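def collate_pad(batch):
    # batch is a list of (seq, label) pairs returned by __getitem__
    seqs, labels = zip(*batch)
    # pad every sequence to the longest one in this batch -> (B, max_len)
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)
    lengths = torch.tensor([len(s) for s in seqs])   # true length of each sequence
    return padded, lengths, torch.stack(labels)      # 0-d label tensors -> (B,)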

loader = DataLoader(VarSeqDataset(), batch_size=8, collate_fn=collate_pad)
padded, lengths, labels = next(iter(loader))
print(padded.shape, lengths)   # torch.Size([8, max_len_in_batch]) tensor([...])
Returning per-sample metadata: dict-shaped batch·python
import torch
from torch.utils.data import DataLoader, Dataset

class TaggedDataset(Dataset):
    def __init__(self):
        self.X = torch.randn(32, 10)
        self.y = torch.randint(0, 2, (32,))
        self.tags = [f"sample_{i}" for i in range(32)]
    def __len__(self): return 32
    def __getitem__(self, i):
        return {'x': self.X[i], 'y': self.y[i], 'tag': self.tags[i]}

def collate_dict(batch):
    return {
        'x': torch.stack([b['x'] for b in batch]),
        'y': torch.stack([b['y'] for b in batch]),
        'tags': [b['tag'] for b in batch],   # keep as list
    }

loader = DataLoader(TaggedDataset(), batch_size=8, collate_fn=collate_dict)
batch = next(iter(loader))
print(batch['x'].shape, batch['y'].shape, batch['tags'])
IterableDataset with proper worker splitting·python
import torch
from torch.utils.data import IterableDataset, DataLoader

class StreamingDataset(IterableDataset):
    """Stream samples from a generator. Splits across workers correctly."""
    def __init__(self, n_total=10_000):
        self.n_total = n_total

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:                    # single-process
            start, end = 0, self.n_total
        else:
            per = self.n_total // info.num_workers
            start = info.id * per
            end = self.n_total if info.id == info.num_workers - 1 else start + per

        for i in range(start, end):
            yield torch.randn(10), i % 2

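if __name__ == "__main__":   # required when workers start via "spawn" (Windows, macOS)
    loader = DataLoader(StreamingDataset(), batch_size=32, num_workers=4)
    print(sum(b[0].size(0) for b in loader))   # 10000: each index yielded exactly once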


Exercise

Implement a Dataset that returns (image_tensor, list_of_bbox_tensors_per_image, image_id_string), where the number of bounding boxes varies per image. Write a collate function that returns a stacked image batch, a Python list of bbox tensors, and a Python list of image IDs. Verify it with a few different batch sizes.
