C.W.K.
Stream
Lesson 06 of 07 · published

Imbalanced Data — Sampler, Class Weight, Focal Loss

~12 min · imbalanced, sampler, class-weight, focal-loss


Three orthogonal strategies, often combined

Most real-world classification problems have a skewed class distribution: 99% normal transactions, 1% fraud. Train naively and the model learns to "always predict normal", which gives 99% accuracy and zero utility. There are three tools, in order of increasing aggressiveness:

  1. Weighted sampling: change which samples the DataLoader picks. Oversample the rare class so the model sees it more often.
  2. Class weights in the loss: keep the sampling proportions unchanged, but make mistakes on the rare class count for more. nn.CrossEntropyLoss(weight=...).
  3. Focal loss: for extreme imbalance (1:1000 and beyond), modify the loss so that easy examples are down-weighted. It originated with RetinaNet.
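The accuracy trap described above can be illustrated with a minimal sketch. It uses hypothetical labels (1,000 samples, 1% fraud) and a degenerate "model" that always predicts the majority class. Accuracy looks excellent, while recall on the class you actually care about is zero.

```python
import torch

# Hypothetical labels: 990 normal (class 0), 10 fraud (class 1), i.e. 1% positives
labels = torch.cat([torch.zeros(990, dtype=torch.long), torch.ones(10, dtype=torch.long)])

# A degenerate "model" that always predicts the majority class
preds = torch.zeros_like(labels)

accuracy = (preds == labels).float().mean().item()

# Recall on the fraud class: fraction of true fraud cases that were flagged
fraud_mask = labels == 1
fraud_recall = (preds[fraud_mask] == 1).float().mean().item()

print(f"accuracy={accuracy:.3f}, fraud recall={fraud_recall:.3f}")
# accuracy=0.990, fraud recall=0.000
```

This is why per-class metrics (recall, confusion matrix) are the right yardstick for all three techniques below, not plain accuracy.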

Combine them freely. Start with class weights (the easiest option, with the fewest hyperparameters). If that is not enough, add weighted sampling. Reach for focal loss only in genuinely extreme cases.

Pitfall to avoid

If you oversample the rare class AND apply class weights, you weight it twice. Pick one, or tune both carefully. The same logic applies to focal loss: its α parameter, when set per class, is itself a class weight.
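The double-weighting effect is easy to quantify. This is a back-of-the-envelope sketch, assuming inverse-frequency weights (1/n_c) for both the sampler and the loss and using the 1000 / 100 / 50 class counts from the code below. It computes each class's share of the total training signal. Either fix alone makes the shares equal. Applying both pushes the emphasis past balance, so the rarest class ends up with 20× the share of the most common one.

```python
import torch

# Class counts matching the examples in this lesson
counts = torch.tensor([1000., 100., 50.])
inv_freq = 1.0 / counts

def share(x):
    """Normalize to fractions of the total, i.e. each class's share of the loss signal."""
    return x / x.sum()

# 1) Naive: each class contributes in proportion to how often it appears
naive = share(counts)
# 2) Inverse-frequency sampling OR inverse-frequency loss weight: balanced shares
balanced = share(counts * inv_freq)
# 3) Both at once: the 1/n_c factor is applied twice, over-correcting toward rare classes
double = share(counts * inv_freq * inv_freq)

for name, s in [("naive", naive), ("one fix", balanced), ("both", double)]:
    print(f"{name:8s}", [round(v, 3) for v in s.tolist()])
```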

Code

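WeightedRandomSampler: oversample rare classes · python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Suppose 1000 class-A samples, 100 class-B samples, 50 class-C samples
labels = torch.cat([
    torch.zeros(1000, dtype=torch.long),
    torch.ones(100, dtype=torch.long),
    torch.full((50,), 2, dtype=torch.long),
])
# Placeholder features so the example runs; substitute your real dataset
features = torch.randn(len(labels), 16)
dataset = TensorDataset(features, labels)

class_counts = torch.bincount(labels).float()   # tensor([1000., 100., 50.])
weights_per_class = 1.0 / class_counts          # rarer class -> larger weight
sample_weights = weights_per_class[labels]      # one weight per sample

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),   # one "epoch" = as many draws as the dataset has samples
    replacement=True,                  # required for over-sampling rare classes
)

# Do not also pass shuffle=True: DataLoader treats sampler and shuffle as mutually exclusive
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Each batch now sees roughly equal counts of A, B, C (in expectation)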
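To check that the sampler really balances the classes, you can draw one epoch of indices and count the labels they map to. This is a minimal, self-contained sketch that repeats the label layout above. The seed is only there to make the draw reproducible.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # reproducible draw

# Same 1000 / 100 / 50 label layout as above
labels = torch.cat([
    torch.zeros(1000, dtype=torch.long),
    torch.ones(100, dtype=torch.long),
    torch.full((50,), 2, dtype=torch.long),
])
class_counts = torch.bincount(labels).float()
sample_weights = (1.0 / class_counts)[labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)

# Iterating the sampler yields dataset indices; map them back to their labels
drawn = torch.tensor(list(sampler))
drawn_counts = torch.bincount(labels[drawn], minlength=3)
print("drawn per class:", drawn_counts.tolist())   # each roughly 1150 / 3, about 383
```

Class C has only 50 distinct samples, so each one is drawn about 7 to 8 times per epoch. That repetition is why heavy oversampling can overfit the rare class.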
Class-weighted CrossEntropyLoss · python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Inverse-frequency weighting: rare classes get bigger weight
class_counts = torch.tensor([1000., 100., 50.])
weights = 1.0 / class_counts
weights = weights / weights.sum() * len(class_counts)  # normalize so the weights average 1

# The weight tensor must be on the same device as the logits
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

# Same training loop as usual; the loss value is now class-aware
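One detail worth knowing: with the default `reduction='mean'`, PyTorch's weighted cross-entropy divides by the sum of the weights of the targets in the batch, not by the batch size. The sketch below checks this against a manual computation. The logits are random and the per-class weights are illustrative values only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.tensor([0, 0, 0, 0, 0, 1, 1, 2])
weights = torch.tensor([0.1, 1.0, 2.0])   # illustrative per-class weights

builtin = F.cross_entropy(logits, targets, weight=weights)

# Manual: per-sample negative log-likelihood of the target class, scaled by its class weight
nll = -F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
w = weights[targets]
manual = (w * nll).sum() / w.sum()   # normalized by total target weight, not by 8

print(builtin.item(), manual.item())
```

Because of this normalization, rescaling all weights by a constant does not change the mean loss. The "average 1" normalization above mainly keeps the weights readable.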
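Focal loss: extreme imbalance · python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Multi-class focal loss: down-weights well-classified (easy) examples.

    alpha: a float (uniform scale) or a 1-D tensor of per-class weights.
    gamma: focusing parameter; gamma=0 gives back (alpha-weighted) cross-entropy.
    """
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        if isinstance(alpha, torch.Tensor):
            # Buffer, so criterion.to(device) moves it along with the module
            self.register_buffer('alpha', alpha.float())
        else:
            self.alpha = float(alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Unweighted CE, so that exp(-ce) is exactly the target-class probability
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)            # prob of correct class
        if isinstance(self.alpha, torch.Tensor):
            alpha_t = self.alpha[targets]   # each sample gets its class's weight
        else:
            alpha_t = self.alpha
        focal = alpha_t * (1 - pt) ** self.gamma * ce
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

# Use just like CrossEntropyLoss.
# A scalar alpha (0.25 is the RetinaNet setting) only rescales this multi-class loss;
# pass a per-class tensor as alpha to make it act as a class weight.
criterion = FocalLoss(alpha=0.25, gamma=2.0)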
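The "down-weights easy examples" claim can be seen directly on two samples. This sketch uses a standalone function with the same formula as `FocalLoss` above (alpha omitted) and hand-picked logits: one confident correct prediction (easy) and one confident wrong prediction (hard). The modulating factor (1 − p_t)^γ shrinks the easy sample's loss by roughly three orders of magnitude, while it barely changes the hard sample's loss.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Per-sample multi-class focal loss without alpha (same formula as FocalLoss above)."""
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)
    return (1 - pt) ** gamma * ce

# Two samples, both with target class 0:
#   row 0 is "easy" (confident and correct), row 1 is "hard" (confident and wrong)
logits = torch.tensor([[4.0, 0.0, 0.0],
                       [0.0, 4.0, 0.0]])
targets = torch.tensor([0, 0])

ce = F.cross_entropy(logits, targets, reduction='none')
fl = focal_loss(logits, targets, gamma=2.0)

# Ratio focal / CE equals (1 - pt)^gamma: near 0 for the easy sample, near 1 for the hard one
print("CE   :", ce.tolist())
print("focal:", fl.tolist())
print("ratio:", (fl / ce).tolist())

# gamma = 0 recovers plain cross-entropy
assert torch.allclose(focal_loss(logits, targets, gamma=0.0), ce)
```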


Exercise

Build a 3-class imbalanced dataset (1000 / 100 / 10 samples). Train a classifier (a) with no rebalancing, (b) with class weights, (c) with WeightedRandomSampler, and (d) with focal loss. Print the confusion matrix for each. The naive run should show ~0% recall on the smallest class; the other three should rescue it.
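A possible starting scaffold for the exercise is sketched below. The helper names (`make_imbalanced_dataset`, `confusion_matrix`, `per_class_recall`) are hypothetical, and the Gaussian-blob data is just one convenient way to get a 1000 / 100 / 10 split. The baseline at the end is the "always predict class 0" failure mode, which your naive run may approach.

```python
import torch

def make_imbalanced_dataset(counts=(1000, 100, 10), dim=2, seed=0):
    """Hypothetical helper: one Gaussian blob per class, with the given class sizes."""
    g = torch.Generator().manual_seed(seed)
    xs, ys = [], []
    for c, n in enumerate(counts):
        center = torch.full((dim,), 2.0 * c)   # class c centered at (2c, 2c, ...)
        xs.append(center + torch.randn(n, dim, generator=g))
        ys.append(torch.full((n,), c, dtype=torch.long))
    return torch.cat(xs), torch.cat(ys)

def confusion_matrix(preds, targets, num_classes):
    """Rows = true class, columns = predicted class."""
    idx = targets * num_classes + preds
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_recall(cm):
    """Diagonal over row sums; a class with no true samples gets recall 0."""
    return cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()

X, y = make_imbalanced_dataset()
print(X.shape, torch.bincount(y).tolist())   # torch.Size([1110, 2]) [1000, 100, 10]

# Baseline: "always predict class 0", the failure mode the naive run should expose
cm = confusion_matrix(torch.zeros_like(y), y, num_classes=3)
print(cm)
print("recall per class:", per_class_recall(cm).tolist())   # [1.0, 0.0, 0.0]
```

For runs (a) to (d), replace the all-zeros predictions with `logits.argmax(dim=1)` from your trained model and compare the recall on class 2 across the four runs.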
