C.W.K.
Stream
Lesson 01 of 06 · published

Serialization — state_dict 에서 .pt2 까지

~12 min · serialization, save, torchscript, torch.export

Level 0Tensor 호기심
0 XP0/62 lessons0/13 achievements
0/120 XP to next level120 XP to go0% complete

세 serialization 형식, 세 use case

  • state_dict (.pth / .pt) — Python OrderedDict 의 parameter + buffer. 가장 유연. 로드 위 model 클래스 필요. training 과 research 의 default.
  • torch.export (.pt2) — exported computational graph + weight. 원본 Python 클래스 없이 loadable. deploy 의 modern path.
  • TorchScript (.pt) — 옛 graph + weight 형식. 여전히 널리 지원되지만 새 코드엔 legacy.
  • ONNX (.onnx) — cross-framework 표준. 너 serving runtime 이 PyTorch 아닐 때 옳은 선택 (ONNX Runtime, TensorRT, OpenVINO, browser 의 onnxruntime-web).

각 형식이 포기하는 것

state_dict 가 PyTorch version 들에 portable 하고 model refactor 에 resilient, 근데 model instantiate 위 Python 클래스 필요. torch.export 와 TorchScript 가 self-contained 지만 PyTorch runtime 에 묶임. ONNX 가 runtime 들에 portable 지만 framework-specific 최적화 잃음.

결정 tree

  • further training 위 공유? → state_dict.
  • PyTorch runtime 통해 deploy? → torch.export (.pt2).
  • ONNX Runtime / TensorRT / browser 통해 deploy? → ONNX (dynamo path 통해).
  • mobile 에 deploy? → ExecuTorch (나중 lesson 에 cover).
  • Apple 에 deploy? → CoreML (또한 나중 lesson).

Code

state_dict — 유연하지만 Python-class 의존·python
import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self): super().__init__(); self.fc = nn.Linear(10, 4)
    def forward(self, x): return self.fc(x)

model = TinyMLP()
torch.save(model.state_dict(), '/tmp/tiny.pth')

# To load — must have the class definition
m2 = TinyMLP()
m2.load_state_dict(torch.load('/tmp/tiny.pth', weights_only=True))
m2.eval()
torch.export (.pt2) — self-contained, modern·python
import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self): super().__init__(); self.fc = nn.Linear(10, 4)
    def forward(self, x): return self.fc(x)

model = TinyMLP().eval()
example = torch.randn(1, 10)

exported = torch.export.export(model, (example,))
torch.export.save(exported, '/tmp/tiny.pt2')

# Load WITHOUT the class definition
loaded = torch.export.load('/tmp/tiny.pt2')
y = loaded.module()(example)        # call .module() to get a callable
print(y.shape)                       # torch.Size([1, 4])
ONNX export — non-PyTorch runtime 위·python
import torch

class TinyMLP(torch.nn.Module):
    def __init__(self): super().__init__(); self.fc = torch.nn.Linear(10, 4)
    def forward(self, x): return self.fc(x)

model = TinyMLP().eval()
example = torch.randn(1, 10)

# In PyTorch 2.x, torch.onnx.export defaults to dynamo=True (modern path)
torch.onnx.export(
    model, (example,), '/tmp/tiny.onnx',
    input_names=['x'], output_names=['y'],
    dynamic_axes={'x': {0: 'batch'}, 'y': {0: 'batch'}},
)

# Then run with ONNX Runtime (separately installed)
# pip install onnxruntime
import onnxruntime as ort
sess = ort.InferenceSession('/tmp/tiny.onnx')
out = sess.run(None, {'x': example.numpy()})
print(out[0].shape)                  # (1, 4)

External links

Exercise

같은 TinyMLP 를 세 방법으로 serialize: state_dict, torch.export, ONNX. 각각 로드 가능하고 같은 input 에 (fp32 tolerance 안) 같은 output 생산하는지 검증. 파일 size 비교 — 작은 model 엔 대략 비슷해야.

Progress

Progress is local-only — sign in to sync across devices.
이 페이지에서 버그를 발견하셨거나 피드백이 있으세요?문제 신고

댓글 0

🔔 답글 알림 (로그인 필요)
로그인댓글을 남기려면 로그인해 주세요.

아직 댓글이 없어요. 첫 댓글을 남겨보세요.