C.W.K.
Stream
Lesson 06 of 08 · published

Convolution Layer: Conv2d and Friends

~14 min · conv2d, cnn, channels, padding


Conv layers in PyTorch's NCHW universe

PyTorch lays out image data as (N, C, H, W): batch, channel, height, width. (The other major convention is NHWC, used by TensorFlow and CoreML, so be ready to permute when you cross frameworks.)
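A minimal sketch of moving a batch between the two layouts, assuming the NHWC data already arrives as a PyTorch tensor (in practice it might come from a NumPy array via torch.from_numpy). permute only reorders the axes; .contiguous() copies the result into a standard NCHW memory layout.

```python
import torch

# A batch in NHWC layout, as a TensorFlow / CoreML pipeline would typically hand it over
x_nhwc = torch.randn(8, 32, 32, 3)          # (N, H, W, C)

# Reorder the axes to PyTorch's NCHW: take dims 0, 3, 1, 2
x_nchw = x_nhwc.permute(0, 3, 1, 2)         # (N, C, H, W)
print(x_nchw.shape)                         # torch.Size([8, 3, 32, 32])

# permute returns a view with rearranged strides; .contiguous() materializes it
x_nchw = x_nchw.contiguous()

# Going back the other way: dims 0, 2, 3, 1
x_back = x_nchw.permute(0, 2, 3, 1)
print(torch.equal(x_back, x_nhwc))          # True
```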

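nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0) learns out_channels filters, each of shape (in_channels, kernel_size, kernel_size). Each filter slides over the input and produces one feature map, so the layer outputs out_channels feature maps. The standard formula for the output spatial size (with the default dilation=1) is: out = floor((in + 2*padding - kernel) / stride) + 1.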
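To sanity-check the formula, here is a small sketch that compares it against the shape nn.Conv2d actually returns. conv_out_size is a hypothetical helper written only for this example; it assumes dilation=1, and integer floor division (//) plays the role of the floor.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel, stride=1, padding=0):
    """Output spatial size of a conv with dilation=1: floor((in + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

x = torch.randn(1, 3, 28, 28)
for kernel, stride, padding in [(3, 1, 0), (3, 1, 1), (5, 2, 2), (3, 2, 0)]:
    conv = nn.Conv2d(3, 4, kernel, stride=stride, padding=padding)
    actual = conv(x).shape[-1]                       # width of the output feature map
    expected = conv_out_size(28, kernel, stride, padding)
    print(f"k={kernel} s={stride} p={padding}: formula {expected}, actual {actual}")
    assert actual == expected
```

Note the last case: (28 - 3) / 2 = 12.5 is floored, so the 28x28 input becomes 13x13 and the last input column is never the center of a window.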

Variants you'll actually use

  • Conv2d: 2D convolution. The default for images.
  • Conv1d: 1D convolution. Useful for sequences (audio waveforms, character-level text, time series).
  • ConvTranspose2d: the so-called 'deconvolution', used for learned upsampling. Found in U-Net's decoder and DCGAN's generator.
  • Depthwise separable conv: built by combining Conv2d(groups=in_channels) with a 1x1 Conv2d. The trick that makes MobileNet / EfficientNet efficient.
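A quick shape check for two of the variants above. The setup is assumed for illustration: a batch of mono audio clips for Conv1d, and a 16x16 feature map upsampled 2x with ConvTranspose2d. The kernel=4, stride=2, padding=1 combination is one common choice for exact 2x upsampling, not the only one.

```python
import torch
import torch.nn as nn

# Conv1d: input is (N, C, L), e.g. 4 mono audio clips of 100 samples each
conv1d = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=5, padding=2)
wave = torch.randn(4, 1, 100)
print(conv1d(wave).shape)      # torch.Size([4, 8, 100])

# ConvTranspose2d with kernel=4, stride=2, padding=1 doubles H and W:
# out = (in - 1) * stride - 2 * padding + kernel = 15 * 2 - 2 + 4 = 32
up = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1)
feat = torch.randn(4, 64, 16, 16)
print(up(feat).shape)          # torch.Size([4, 32, 32, 32])

# Unlike Conv2d, the transposed conv stores its weight as (in, out, kH, kW)
print(up.weight.shape)         # torch.Size([64, 32, 4, 4])
```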

padding shorthand

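In PyTorch 1.10+, padding can also be passed as the string 'same', which keeps the spatial dims unchanged when stride=1 (PyTorch rejects 'same' for strided convolutions). It is a convenient replacement for the padding arithmetic everyone used to do by hand. Use it unless you have a specific reason to shrink the spatial dims.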
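A sketch comparing padding='same' with the hand-computed padding=k // 2 for odd kernel sizes (even kernels need asymmetric padding, which this example deliberately avoids), plus what happens when 'same' is combined with stride=2.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 32, 32)

# With padding='same' and stride=1, the 32x32 spatial size is preserved
for k in (3, 5, 7):
    same = nn.Conv2d(3, 8, kernel_size=k, padding="same")
    manual = nn.Conv2d(3, 8, kernel_size=k, padding=k // 2)   # hand-computed equivalent for odd k
    print(k, same(x).shape, manual(x).shape)                # both torch.Size([2, 8, 32, 32])

# 'same' is only allowed with stride=1; a strided conv raises ValueError at construction
rejected = False
try:
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding="same")
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```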

Code

Conv2d basics·python
import torch
import torch.nn as nn

# 3 RGB channels in, 16 feature maps out, 3x3 kernel
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
                 stride=1, padding=1)

x = torch.randn(8, 3, 32, 32)   # batch=8, RGB, 32x32 image
y = conv(x)
print(y.shape)                   # torch.Size([8, 16, 32, 32]) — same spatial
print(conv.weight.shape)         # torch.Size([16, 3, 3, 3])  — out, in, kH, kW
print(conv.bias.shape)           # torch.Size([16])

# stride=2 halves spatial dims
conv_down = nn.Conv2d(3, 16, 3, stride=2, padding=1)
print(conv_down(x).shape)        # torch.Size([8, 16, 16, 16])
A simple CNN: the canonical pattern·python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 32x32 → 16x16

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 16x16 → 8x8

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),         # global avg pool → 1x1
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)                     # (B, 128, 1, 1) → (B, 128)
        return self.classifier(x)

model = SimpleCNN()
print(model(torch.randn(4, 3, 32, 32)).shape)   # torch.Size([4, 10])
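To see exactly where the spatial dims shrink, the sketch below pushes a batch through the same feature stack one layer at a time and records each output shape. The stack is rebuilt inline (a copy of SimpleCNN.features) only so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Copy of SimpleCNN.features, rebuilt so this snippet is standalone
features = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
)

x = torch.randn(4, 3, 32, 32)
shapes = []
for layer in features:            # nn.Sequential iterates over its submodules in order
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{layer.__class__.__name__:<18} {tuple(x.shape)}")
```

Only the pooling layers change H and W; the padded 3x3 convs change the channel count and leave the spatial size alone.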
Depthwise separable conv — MobileNet trick·python
import torch.nn as nn

class DepthwiseSeparable(nn.Module):
    """Replace a regular conv with depthwise + pointwise — far fewer params."""
    def __init__(self, in_ch, out_ch, kernel=3):
        super().__init__()
        # Depthwise: each input channel gets its own kernel
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel,
                                    padding=kernel // 2, groups=in_ch)
        # Pointwise: 1x1 conv to mix channels
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# Compare param counts
regular = nn.Conv2d(64, 128, 3, padding=1)
sep = DepthwiseSeparable(64, 128, 3)

regular_params = sum(p.numel() for p in regular.parameters())
sep_params = sum(p.numel() for p in sep.parameters())
print(f"Regular: {regular_params:,}")    # 73,856
print(f"Separable: {sep_params:,}")       # 8,960, about 8x fewer
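The parameter counts follow from a short calculation, with kernel size k, C_in input channels, C_out output channels, and one bias per output channel (here k = 3, C_in = 64, C_out = 128):

```latex
\begin{aligned}
P_{\text{regular}} &= k^2 C_{in} C_{out} + C_{out}
  = 9 \cdot 64 \cdot 128 + 128 = 73{,}856 \\
P_{\text{separable}} &= \underbrace{k^2 C_{in} + C_{in}}_{\text{depthwise}}
  + \underbrace{C_{in} C_{out} + C_{out}}_{\text{pointwise}}
  = (576 + 64) + (8{,}192 + 128) = 8{,}960
\end{aligned}
```

The ratio 73,856 / 8,960 is about 8.2. Ignoring biases, it equals k² C_out / (k² + C_out), which approaches k² = 9 as C_out grows.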


Exercise

Build a small ResNet-style block: Conv2d(64, 64, 3, padding=1) → BatchNorm → ReLU → Conv2d(64, 64, 3, padding=1) → BatchNorm, then add the input back (the skip connection) before the final ReLU. Verify that the spatial dims and channel count are preserved.
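One possible solution sketch, if you want to check your work. BasicBlock is a name chosen for this example and is not imported from any library. The convs keep their default bias=True as the exercise specifies; many ResNet implementations use bias=False before BatchNorm, since BN's learned shift makes the conv bias redundant.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Hypothetical ResNet-style block: conv-BN-ReLU-conv-BN, add the input, then ReLU."""
    def __init__(self, channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + x          # skip connection: shapes must match exactly
        return F.relu(out)     # the final ReLU comes after the addition

block = BasicBlock(64)
x = torch.randn(2, 64, 16, 16)
y = block(x)
print(y.shape)                 # torch.Size([2, 64, 16, 16])
assert y.shape == x.shape      # spatial dims and channel count preserved
```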
