- Published on
PyTorch INT8 양자화 후 정확도 급락 해결 7단계
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 비용을 줄이려고 PyTorch 모델을 INT8로 양자화했는데, 지연시간은 줄었지만 정확도가 체감될 정도로 급락하는 경우가 흔합니다. 대부분은 “INT8이라서 어쩔 수 없다”가 아니라, 관측(캘리브레이션) 데이터 품질, 연산자(fusion/observer) 설정, 레이어별 민감도, 백엔드(qnnpack/fbgemm)와 배치/채널 축 선택 같은 실무 설정 문제로 발생합니다.
이 글은 원인 분리를 위한 체크리스트가 아니라, 실제로 정확도를 되살릴 가능성이 높은 순서대로 진행하는 7단계 복구 절차로 구성했습니다. 각 단계는 “무엇을 바꾸고”, “왜 효과가 있고”, “어떻게 검증하는지”에 초점을 맞춥니다.
참고: 양자화는 모델 종류(CNN/Transformer), 배포 환경(x86/ARM), 목표(속도/메모리)마다 최적점이 다릅니다. 아래 절차는 공통적으로 재현 가능한 디버깅 루틴을 제공합니다.
1단계: 기준선 재현과 평가 파이프라인 고정
정확도 급락 이슈는 대부분 평가 파이프라인이 양자화 전후로 미세하게 달라져서 더 크게 보이거나, 반대로 문제를 숨깁니다. 가장 먼저 할 일은 다음을 고정하는 것입니다.
- 전처리/후처리(정규화, 리사이즈, 토크나이저, padding) 동일
model.eval()강제- dropout, stochastic depth, batchnorm 동작 고정
- seed 고정 및 평가 데이터셋/샘플 수 고정
import torch
import random
import numpy as np
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(42)
model_fp32.eval()
# 양자화 모델도 반드시 eval
model_int8.eval()
@torch.inference_mode()
def evaluate(model, dataloader):
correct = 0
total = 0
for x, y in dataloader:
logits = model(x)
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.numel()
return correct / total
acc_fp32 = evaluate(model_fp32, val_loader)
acc_int8 = evaluate(model_int8, val_loader)
print(acc_fp32, acc_int8)
이 단계에서 해야 할 핵심은 “정확도 급락”을 수치로 고정하고, 이후 단계에서 한 번에 하나의 변경만 적용해 영향도를 측정하는 것입니다.
2단계: 백엔드와 qconfig를 명시적으로 고정
PyTorch 정적 양자화는 백엔드에 따라 관측기(observer)와 양자화 방식이 달라집니다.
- x86 서버: 보통
fbgemm - ARM 모바일: 보통
qnnpack
백엔드가 달라지면 같은 모델도 정확도/속도 특성이 크게 변합니다. 또한 기본 qconfig가 모델에 맞지 않으면 정확도 급락이 쉽게 발생합니다.
import torch
# x86라면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"
from torch.ao.quantization import get_default_qconfig
qconfig = get_default_qconfig(torch.backends.quantized.engine)
print(qconfig)
추가로, 실무에서는 activation은 per-tensor affine, weight는 per-channel symmetric이 일반적으로 유리합니다. 기본 qconfig가 이를 만족하는지 확인하고, 필요하면 명시적으로 구성합니다.
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
custom_qconfig = QConfig(
activation=HistogramObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)
이 단계의 목표는 “내 환경에서 어떤 양자화가 실제로 적용되는지”를 고정해, 이후 캘리브레이션과 레이어별 튜닝이 의미 있게 만들도록 하는 것입니다.
3단계: 캘리브레이션 데이터 품질과 분포를 재점검
정적 INT8의 정확도는 캘리브레이션 데이터가 전부라고 해도 과언이 아닙니다. 흔한 실패 패턴은 다음과 같습니다.
- 캘리브레이션에 학습 데이터 일부를 대충 사용했는데, 실제 서빙 분포와 다름
- 전처리 누락(정규화/리사이즈 방식 차이)
- 샘플 수가 너무 적음(수십 장 수준)
- 특정 클래스/길이/해상도에 편향
권장 접근:
- 최소 수백~수천 샘플(모델/도메인에 따라)
- 서빙 트래픽에서 샘플링한 “현실 데이터” 포함
- 입력 길이/해상도/밝기 등 분포를 실제와 맞춤
from torch.ao.quantization import prepare
model_to_quant = model_fp32
model_to_quant.eval()
model_to_quant.qconfig = custom_qconfig
prepared = prepare(model_to_quant)
@torch.inference_mode()
def calibrate(model, dataloader, num_batches=50):
for i, (x, _) in enumerate(dataloader):
model(x)
if i + 1 >= num_batches:
break
calibrate(prepared, calib_loader, num_batches=200)
캘리브레이션은 “학습”이 아니라 “관측”입니다. 관측이 틀리면 quant scale/zero-point가 틀어지고, 그 결과가 정확도 급락으로 직결됩니다.
4단계: 관측기(Observer) 종류를 바꿔 이상치(outlier) 대응
정확도 급락의 대표 원인은 activation 분포의 긴 꼬리(outlier) 입니다. MinMax 기반 관측기는 이상치 하나에 스케일이 끌려가서, 대부분의 값이 양자화 그리드에서 뭉개집니다.
대응책은 histogram 기반(또는 percentile/kl-divergence 계열) 관측기를 쓰거나, 레이어별로 관측기를 다르게 적용하는 것입니다.
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization import QConfig
qconfig_hist = QConfig(
activation=HistogramObserver.with_args(
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False
),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)
실무 팁:
- 첫 시도는 activation에
HistogramObserver적용 - 특정 블록에서만 급락하면 해당 블록만 별도 qconfig 적용(모듈별 qconfig override)
이 단계는 “스케일이 이상치에 끌려가는 문제”를 완화해 체감 정확도를 크게 복구하는 경우가 많습니다.
5단계: 레이어별 민감도 분석 후 선택적 FP16/FP32 유지(혼합 정밀도)
모든 레이어를 INT8로 만들 필요는 없습니다. 특히 아래는 민감도가 높아 정확도 급락을 유발하기 쉽습니다.
- 첫 Conv/Stem
- 마지막 분류기(FC)
- LayerNorm/Softmax 주변(Transformer 계열)
- residual 합산 직전/직후
가장 현실적인 접근은 민감한 구간은 FP32로 남기고, 나머지만 INT8로 내려 속도/메모리 이득을 가져가는 것입니다.
정적 양자화에서 완전 자동으로 해결이 안 되면, 해당 모듈의 qconfig = None으로 제외하는 전략이 효과적입니다.
import torch.nn as nn
from torch.ao.quantization import prepare, convert
model = model_fp32
model.eval()
model.qconfig = qconfig_hist
# 예: 마지막 분류기는 양자화 제외
if hasattr(model, "fc") and isinstance(model.fc, nn.Linear):
model.fc.qconfig = None
prepared = prepare(model)
calibrate(prepared, calib_loader, num_batches=200)
model_int8 = convert(prepared)
민감도 분석을 더 체계적으로 하려면, 블록 단위로 하나씩 제외하면서 정확도 변화를 기록해 “정확도에 가장 큰 영향을 주는 블록”을 찾으면 됩니다.
6단계: QAT(Quantization Aware Training)로 미세 보정
캘리브레이션과 관측기 튜닝만으로 복구가 안 되면, 모델이 INT8의 양자화 노이즈를 견디도록 QAT로 미세 조정하는 것이 정석입니다. 특히 Transformer나 작은 모델(표현력이 낮은 모델)은 PTQ만으로 버티기 어렵습니다.
QAT 핵심 포인트:
- 학습률을 낮게, 짧게(수 에폭) 시작
- 대표 데이터 분포로 학습
- 배치정규화(freeze)나 EMA 등 안정화 고려
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
torch.backends.quantized.engine = "fbgemm"
model = model_fp32
model.train()
model.qconfig = get_default_qat_qconfig("fbgemm")
qat_model = prepare_qat(model)
optimizer = torch.optim.AdamW(qat_model.parameters(), lr=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(2):
for x, y in train_loader:
optimizer.zero_grad()
logits = qat_model(x)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
qat_model.eval()
from torch.ao.quantization import convert
int8_model = convert(qat_model)
QAT는 비용이 들지만, “INT8에서만 발생하는 오차”를 학습이 흡수하게 만들기 때문에 정확도 복구에 가장 확실한 카드입니다.
7단계: 디버깅을 자동화하고, 실패를 재발 방지용 체크로 고정
정확도 급락은 한 번 해결해도 모델/데이터/전처리/백엔드가 바뀌면 재발합니다. 따라서 아래를 자동화해두면 운영 효율이 올라갑니다.
- FP32 vs INT8 정확도 차이(절대값, 상대값) CI 체크
- 캘리브레이션 샘플 수/분포 검사
- 레이어별 스케일 통계 로그(극단값 감지)
- 백엔드(engine) 및 qconfig 해시 기록
간단한 “정확도 게이트” 예시는 다음과 같습니다.
def assert_quant_quality(acc_fp32, acc_int8, max_drop=0.01):
drop = acc_fp32 - acc_int8
if drop > max_drop:
raise RuntimeError(f"INT8 accuracy drop too large: {drop:.4f}")
assert_quant_quality(acc_fp32, acc_int8, max_drop=0.02)
또한 양자화 파이프라인은 작은 설정 차이로 결과가 크게 바뀌므로, 운영 디버깅 관점에서 “원인 추적 루틴”을 문서화해두는 게 중요합니다. 장애 원인을 단계적으로 좁히는 방식은 인프라/서빙 영역에서도 동일하게 유효합니다. 예를 들어 타임아웃 급증을 체크포인트로 분해해 추적하는 접근은 AWS ALB 502/504 급증 - 타임아웃 7곳 점검 같은 글의 문제 해결 방식과 결이 같습니다.
자주 터지는 원인 요약(체크리스트)
아래 중 하나라도 해당하면, 위 7단계를 순서대로 적용했을 때 복구될 확률이 큽니다.
- 캘리브레이션 데이터가 너무 적거나 서빙 분포와 다름
- activation outlier가 심한데 MinMax 관측기를 사용
- 백엔드가 바뀌었는데 qconfig를 고정하지 않음
- 레이어별 민감도를 무시하고 전부 INT8로 변환
- PTQ로 안 되는데도 QAT를 시도하지 않음
마무리: “PTQ 튜닝”으로 안 되면 빨리 QAT로 넘어가기
정확도 급락을 해결하는 가장 빠른 길은, 막연히 설정을 바꾸는 게 아니라 재현 가능한 기준선을 만들고, 캘리브레이션 분포와 observer/outlier를 먼저 잡고, 마지막에 선택적 양자화 또는 QAT로 확정 짓는 것입니다.
추가로, 양자화된 모델을 RAG 파이프라인이나 검색/재랭킹에 붙이는 경우에는 모델 정확도 저하가 곧바로 답변 품질 저하(환각 증가)로 이어질 수 있습니다. 그런 경우 검색 품질 측면의 안전장치도 함께 설계하는 것이 좋습니다. 관련해서는 RAG 환각 줄이기 - ColBERTv2+Rerank 최적화 글이 참고가 됩니다.
원하시면 모델 종류(CNN, ViT, BERT 계열), 배포 타깃(x86, ARM), 현재 사용 중인 양자화 방식(동적, 정적, QAT), 그리고 FP32 대비 정확도 하락 폭을 알려주시면, 위 7단계를 바탕으로 “어느 단계부터 시작해야 가장 빠른지”를 케이스별로 좁혀드릴 수 있습니다.