C.W.K.
Stream
Lesson 03 of 05

Quantization-aware training and pruning

~12 min · qat, pruning, model-optimization


When post-training quantization isn't accurate enough

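Post-training quantization is convenient, but it can cost accuracy on sensitive tasks. Quantization-Aware Training (QAT) simulates int8 quantization in the forward pass during training (fake-quantization ops): the model itself stays float32, but it learns weights that are robust to the int8 conversion. The final TFLite conversion then produces the actual int8 model.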
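
A minimal sketch of what a fake-quantization op computes, in plain Python: map a float range onto int8 with a scale and zero point, round, clamp, then map back to float. The asymmetric per-tensor scheme, the fixed range (-0.5, 1.5), and the helper names `quant_params` and `fake_quant` are illustrative assumptions, not tfmot's API; real QAT layers track these ranges during training.

```python
def quant_params(min_val: float, max_val: float) -> tuple[float, int]:
    """Scale and zero point mapping [min_val, max_val] onto int8 [-128, 127]."""
    min_val = min(min_val, 0.0)  # the range must contain 0 so 0.0 is exact
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / 255.0
    zero_point = round(-128 - min_val / scale)
    return scale, zero_point


def fake_quant(x: float, scale: float, zero_point: int) -> float:
    """Quantize x to int8, then dequantize it back to float."""
    q = round(x / scale) + zero_point  # Python rounds half to even
    q = max(-128, min(127, q))         # clamp to the int8 range
    return (q - zero_point) * scale    # the float value int8 can represent


scale, zp = quant_params(-0.5, 1.5)
for x in [-0.5, -0.1234, 0.0, 0.7, 1.5, 2.0]:
    print(f"{x:+.4f} -> {fake_quant(x, scale, zp):+.4f}")
```

Inputs inside the range come back within `scale / 2` of their original value, while 2.0 lies outside it and saturates near 1.5. That rounding and clipping error is exactly what QAT exposes the weights to during training. Python's `round()` rounds half to even, which may differ from a given kernel's rounding mode.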

Typical accuracy gain over post-training quantization: 1–3 percentage points (absolute). That matters when you are at the edge of your accuracy budget.

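Magnitude-based pruning sets the smallest-magnitude weights to zero, producing a sparse model. Combined with quantization, it can shrink a model by ~10× with minimal accuracy loss. By default the zeros are still stored densely in the TFLite file, so the size reduction from pruning only shows up once the file is compressed (e.g. with gzip). The TF Model Optimization Toolkit (tensorflow_model_optimization) provides a Keras layer wrapper that gradually increases sparsity during training.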
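
The pruning and size claims can be checked on toy data. This stdlib-only sketch zeroes the 80% smallest-magnitude weights, quantizes the result to int8, and compares raw and gzip-compressed sizes. Random Gaussian weights stand in for a trained layer, and symmetric per-tensor quantization is a simplifying assumption (converters often quantize weights per channel), so the exact ratios will not match a real model.

```python
import gzip
import random
from array import array


def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest |w|."""
    k = int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:k]:
        pruned[i] = 0.0
    return pruned


def quantize_int8(weights):
    """Symmetric per-tensor int8: zero point 0, values in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127.0 or 1.0  # guard all-zero input
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def sizes(weights):
    """Raw and gzip-compressed byte counts when stored as float32."""
    raw = array('f', weights).tobytes()
    return len(raw), len(gzip.compress(raw))


rng = random.Random(0)  # fixed seed so the toy weights are reproducible
dense = [rng.gauss(0.0, 0.05) for _ in range(10_000)]
sparse = prune_by_magnitude(dense, 0.80)
q, scale = quantize_int8(sparse)
q_bytes = array('b', q).tobytes()

print("dense  float32 raw/gzip:", sizes(dense))
print("pruned float32 raw/gzip:", sizes(sparse))
print("pruned int8    raw/gzip:", len(q_bytes), len(gzip.compress(q_bytes)))
```

Pruning leaves the raw float32 byte count unchanged; only the compressed size drops. Because symmetric quantization maps 0.0 to 0, the pruned zeros survive quantization and keep compressing well, which is why the two techniques stack.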

Code

QAT workflow·python
# pip install tensorflow-model-optimization
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# 1. Start with a pretrained float32 model
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet', input_shape=(224, 224, 3),
)
# ... assume base_model is fine-tuned on your task ...
# train_data / train_labels below are placeholders for your own NumPy arrays.

# 2. Wrap with QAT: the model's layers get fake-quant ops
q_aware = tfmot.quantization.keras.quantize_model(base_model)

# 3. Recompile (the wrapper changes the layer structure)
q_aware.compile(
    optimizer='adam',
    # MobileNetV2's default head ends in softmax, so the model outputs
    # probabilities, not logits
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# 4. Fine-tune for a few epochs to learn quantization-robust weights
q_aware.fit(train_data, train_labels, batch_size=64, epochs=5,
            validation_split=0.1)

# 5. Convert to an actual int8 TFLite model
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_qat = converter.convert()

with open('model_qat.tflite', 'wb') as f:
    f.write(tflite_qat)
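
Step 4 trains float weights through a forward pass that rounds them. Rounding has zero gradient almost everywhere, so fake-quant ops typically use a straight-through estimator (STE): the gradient with respect to the quantized weight is applied directly to the float "shadow" weight. A toy one-weight model (hypothetical, not tfmot's internals) sketches the idea:

```python
def fake_quant(w, scale, qmin=-127, qmax=127):
    """Symmetric fake quantization: round to the int8 grid, then dequantize."""
    return max(qmin, min(qmax, round(w / scale))) * scale


def train_ste(xs, ys, w=0.0, scale=0.05, lr=0.1, steps=50):
    """Fit y = w * x while the forward pass only ever sees quantized w."""
    for _ in range(steps):
        wq = fake_quant(w, scale)  # forward pass uses the quantized weight
        # Gradient of the mean squared error with respect to wq
        grad = sum(2 * (wq * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        # STE: treat d(wq)/d(w) as 1 and update the float shadow weight
        # (real fake-quant ops also zero the gradient outside the clip range)
        w -= lr * grad
    return w, fake_quant(w, scale)


xs = [1.0, 2.0, 3.0]
ys = [0.7 * x for x in xs]
w_float, w_quant = train_ste(xs, ys)
print(f"float weight {w_float:.4f}, quantized weight {w_quant:.4f}")
```

The float weight settles where its quantized version fits the data, which is the sense in which QAT learns quantization-robust weights.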
Magnitude pruning·python
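import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Placeholders: `model` is your trained float32 Keras model, and
# x_train / y_train are your training arrays. Example hyperparameters:
batch_size = 128
epochs = 2
num_images = x_train.shape[0]

prune = tfmot.sparsity.keras.prune_low_magnitude

# Schedule: ramp 50% → 80% sparsity over training
end_step = np.ceil(num_images / batch_size).astype(int) * epochs
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    )
}

pruned = prune(model, **pruning_params)
# from_logits=True assumes `model` outputs logits (no final softmax)
pruned.compile(optimizer='adam',
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

# REQUIRED callback: without it, the pruning step (and so sparsity) never updates
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='/tmp/pruning'),  # optional logs
]
pruned.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
           callbacks=callbacks)

# Strip the pruning wrappers before export
final = tfmot.sparsity.keras.strip_pruning(pruned)
converter = tf.lite.TFLiteConverter.from_keras_model(final)
# No representative dataset, so this is dynamic-range quantization (int8 weights)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pruned_quant_tflite = converter.convert()

with open('model_pruned_quant.tflite', 'wb') as f:
    f.write(pruned_quant_tflite)
# Result: often ~10× smaller than the original float32 model once the file is
# gzip-compressed; pruned zeros are stored densely, so the raw .tflite shrinks
# only from quantization (~4×)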
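
The `end_step` arithmetic and the 50% → 80% ramp in the pruning code can be reproduced by hand. The sketch below evaluates the polynomial-decay formula documented for tfmot's `PolynomialDecay`, sparsity = final + (initial − final) · (1 − progress)^power, with power defaulting to 3. The dataset size is a hypothetical example, and tfmot's `frequency` setting (masks are updated every 100 steps by default) is ignored here.

```python
import math


def total_steps(num_images, batch_size, epochs):
    """Optimizer steps over training: batches per epoch times epochs."""
    return math.ceil(num_images / batch_size) * epochs


def polynomial_sparsity(step, initial_sparsity, final_sparsity,
                        begin_step, end_step, power=3):
    """Target sparsity at `step`, clamped to [begin_step, end_step]."""
    step = min(max(step, begin_step), end_step)
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


# Hypothetical example: 60,000 images, batch size 128, 2 epochs
total = total_steps(60_000, 128, 2)
for s in [0, total // 4, total // 2, total]:
    print(s, round(polynomial_sparsity(s, 0.50, 0.80, 0, total), 4))
```

With power 3 the sparsity rises quickly at first and then flattens toward the 80% target, giving the network time to recover accuracy as the last weights are removed.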
