- Published on
PyTorch 2.x PTQ int8 정확도 급락 해결 가이드
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 비용을 줄이기 위해 FP32 모델을 int8로 내리는 PTQ(Post-Training Quantization)를 적용했는데, PyTorch 2.x로 넘어오면서 정확도가 갑자기 크게 떨어지는 경우가 있습니다. 특히 torch.ao.quantization 계열 API와 백엔드(예: fbgemm, qnnpack) 및 관측기(observer) 기본값이 얽히면서, “양자화는 됐는데 성능이 망가진” 상태가 쉽게 만들어집니다.
이 글은 PyTorch 2.x 환경에서 PTQ int8 정확도 급락을 유발하는 원인을 체크리스트로 정리하고, 재현 가능한 코드로 해결 전략을 제시합니다. 목표는 단순히 “양자화 성공”이 아니라, 정확도 하락을 예측 가능하게 줄이고 운영에서 안전하게 굴릴 수 있는 설정을 확보하는 것입니다.
PyTorch 2.x PTQ에서 정확도가 급락하는 대표 패턴
정확도 급락은 대부분 아래 중 하나(혹은 조합)입니다.
- 캘리브레이션 데이터가 부족하거나 분포가 다름
- 관측기(observer) 설정이 모델/데이터에 부적합
- 연산자 패턴이 제대로 fusion 되지 않아 양자화 경로가 깨짐
- 레이어별로 per-tensor 스케일이 과도하게 거칠어짐(특히 Conv/Linear weight)
- 활성값(activation) 범위가 outlier에 의해 찌그러짐
- 백엔드 불일치(서버는
fbgemm, 모바일은qnnpack) 또는 지원 연산자 미스매치 - 평가 시 preprocessing / postprocessing 불일치
이 중 1~5는 “정확도 급락”의 80%를 차지합니다.
먼저 확인할 것: 백엔드와 dtype 경로가 맞는가
PyTorch PTQ는 “어떤 엔진에서 어떤 quantized op로 실행되는지”가 중요합니다. 서버 CPU는 보통 fbgemm, ARM 계열은 qnnpack입니다.
import torch
print(torch.__version__)
print(torch.backends.quantized.supported_engines)
# 서버 CPU라면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"
print("engine:", torch.backends.quantized.engine)
- 엔진을 바꾸면 관측기 기본 동작과 성능/정확도에 영향이 있습니다.
- 동일 모델이라도 엔진이 바뀌면 미세하게 결과가 달라질 수 있습니다.
PTQ 파이프라인 정석: 준비-캘리브레이션-변환
PyTorch 2.x에서 Eager Mode PTQ의 기본 흐름은 아래입니다.
prepare또는prepare_fx로 관측기 삽입- 캘리브레이션 데이터로 몇 배치 inference 수행(학습 아님)
convert또는convert_fx로 int8 모듈로 변환
정확도 급락은 대개 2번(캘리브레이션)과 1번(관측기 설정)에서 발생합니다.
Eager PTQ 예제: Linear 모델에 대한 안정적인 기본 템플릿
아래는 최소한의 실수로 PTQ를 수행하는 템플릿입니다. 실제 모델에서는 fusion 가능한 패턴(Conv+BN+ReLU 등)이 더 중요하지만, 구조는 동일합니다.
import torch
import torch.nn as nn
import torch.ao.quantization as tq
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(512, 512)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(512, 10)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
model_fp32 = MLP().eval()
# 1) 백엔드 지정
torch.backends.quantized.engine = "fbgemm"
# 2) qconfig 설정 (기본값은 무난하지만, 급락하면 여기부터 바꿈)
model_fp32.qconfig = tq.get_default_qconfig(torch.backends.quantized.engine)
# 3) 관측기 삽입
model_prepared = tq.prepare(model_fp32, inplace=False)
# 4) 캘리브레이션 (대표 분포의 데이터로 충분히)
with torch.inference_mode():
for _ in range(200):
x = torch.randn(32, 512) # 실제로는 validation에서 샘플링
model_prepared(x)
# 5) int8 변환
model_int8 = tq.convert(model_prepared, inplace=False)
# sanity check
with torch.inference_mode():
y = model_int8(torch.randn(1, 512))
print(y.shape)
이 템플릿에서 “정확도 급락”이 나면, 다음 섹션의 체크리스트를 순서대로 적용하세요.
체크리스트 1: 캘리브레이션 데이터 품질과 개수
PTQ는 학습이 아니라 통계(최솟값/최댓값 혹은 히스토그램)를 수집하는 과정입니다. 캘리브레이션 데이터가 실제 입력 분포를 대표하지 못하면, activation 스케일이 틀어져서 정확도가 크게 떨어집니다.
권장 가이드
- 최소 수십 배치가 아니라 수백~수천 샘플을 권장합니다(모델/도메인에 따라 상이).
- augmentation이 많은 모델은 augmentation 적용 여부를 학습/추론과 일치시킵니다.
- 분포가 다양한 서비스라면, 캘리브레이션 샘플을 “평균적인 트래픽”으로 구성합니다.
흔한 실수
- train loader를 그대로 써서 label 포함 augmentation이 달라짐
- normalization 누락(예: mean/std 적용 안 함)
- evaluation 단계에서 resize/crop 방식이 달라짐
정확도 급락이 “갑자기” 생겼다면, 모델/코드 변경보다 캘리브레이션 입력 파이프라인 변경이 원인인 경우가 많습니다.
체크리스트 2: 관측기(observer)와 스케일 계산 방식 변경
기본 qconfig는 보통 activation에 MinMaxObserver 혹은 HistogramObserver 계열을 씁니다. 여기서 문제가 되는 전형적인 케이스는 outlier입니다.
- activation에 드물게 큰 값이 섞이면 MinMax 기반 스케일이 커져서, 대부분의 값이 int8에서 뭉개집니다.
해결: activation은 히스토그램 기반, weight는 per-channel
서버 CPU fbgemm 기준으로, 많은 모델에서 아래 조합이 안정적입니다.
- activation:
HistogramObserver(또는 기본이 히스토그램인 qconfig) - weight: per-channel symmetric (특히 Conv/Linear)
import torch
import torch.ao.quantization as tq
torch.backends.quantized.engine = "fbgemm"
act_observer = tq.HistogramObserver.with_args(
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
)
wt_observer = tq.PerChannelMinMaxObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
)
custom_qconfig = tq.QConfig(
activation=act_observer,
weight=wt_observer,
)
per_channel_symmetric는 weight 양자화에서 정확도 방어에 매우 효과적인 경우가 많습니다.- activation은 데이터 분포에 민감하므로, outlier가 의심되면 히스토그램 계열을 우선 고려합니다.
체크리스트 3: Fusion 실패로 인해 양자화 경로가 깨지는 문제
Conv-BN-ReLU 같은 패턴은 fusion이 되어야 양자화 정확도와 성능이 좋아집니다. fusion이 안 되면:
- BN이 별도로 남아 activation 분포가 달라지고
- quantized op가 기대한 형태로 생성되지 않거나
- dequant-quant가 불필요하게 삽입되어 수치오차가 커질 수 있습니다.
Eager 모드에서는 보통 fuse_modules를 사용합니다.
import torch
import torch.nn as nn
import torch.ao.quantization as tq
class ConvBNReLU(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
m = ConvBNReLU().eval()
# fusion
m_fused = tq.fuse_modules(m, [["conv", "bn", "relu"]], inplace=False)
torch.backends.quantized.engine = "fbgemm"
m_fused.qconfig = tq.get_default_qconfig("fbgemm")
m_prepared = tq.prepare(m_fused, inplace=False)
with torch.inference_mode():
for _ in range(200):
m_prepared(torch.randn(8, 3, 224, 224))
m_int8 = tq.convert(m_prepared, inplace=False)
print(m_int8)
만약 fusion 대상 모듈 이름이 다르거나, inplace=True로 원본 모델을 덮어써서 디버깅이 어려워지면 문제를 찾기 힘듭니다. 급락 상황에서는 inplace=False로 단계별 모델을 출력해 구조를 확인하는 것이 좋습니다.
체크리스트 4: 특정 레이어만 양자화에서 제외하기(부분 양자화)
정확도 급락이 특정 블록에서 발생하는 경우가 있습니다. 예를 들어:
- 첫 번째 stem conv
- 마지막 classifier
- LayerNorm / Softmax 주변
- attention 일부(특히 PTQ에서 민감)
이 경우 “전부 int8”을 고집하기보다, 민감 레이어는 FP32로 남기고 나머지만 int8로 두는 것이 실전적으로 더 낫습니다.
Eager 모드에서는 모듈 단위로 qconfig = None을 줄 수 있습니다.
import torch.ao.quantization as tq
def disable_quant_for_module(module, module_name_keywords=("layernorm", "softmax")):
for name, child in module.named_modules():
lowered = name.lower()
if any(k in lowered for k in module_name_keywords):
child.qconfig = None
# 사용 예
# model.qconfig = custom_qconfig
# disable_quant_for_module(model)
# prepared = tq.prepare(model)
이 접근은 “정확도는 살리고, 속도/비용도 대부분은 절감”하는 절충안으로 자주 사용됩니다.
체크리스트 5: 평가 지표와 비교 방법이 올바른가
int8 변환 후 정확도 비교를 할 때, 아래 실수로 “급락”처럼 보이는 경우가 있습니다.
- FP32는
model.eval()인데 int8은train()상태 - FP32는
torch.inference_mode()인데 int8은 grad가 켜짐 - preprocessing이 서로 다름
- batch size가 달라져 BN 통계가 다르게 작동(특히 fusion 전/후)
비교 루프를 고정하세요.
import torch
def eval_top1(model, dataloader, device="cpu"):
model.eval()
correct = 0
total = 0
with torch.inference_mode():
for x, y in dataloader:
x = x.to(device)
y = y.to(device)
logits = model(x)
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.numel()
return correct / max(total, 1)
FX Graph Mode로 전환할 때의 포인트
PyTorch 2.x에서는 FX 기반(prepare_fx, convert_fx)이 더 강력한 경우가 많습니다. 패턴 매칭과 변환이 더 체계적이고, 대규모 모델에서 “어떤 연산이 양자화됐는지” 추적하기가 비교적 쉽습니다.
다만 FX는 모델이 trace 가능해야 하고, 동적 제어 흐름이 많으면 추가 작업이 필요합니다.
import torch
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import torch.ao.quantization as tq
torch.backends.quantized.engine = "fbgemm"
qconfig_mapping = tq.QConfigMapping().set_global(tq.get_default_qconfig("fbgemm"))
example_inputs = (torch.randn(1, 3, 224, 224),)
model_fp32 = ...
model_fp32.eval()
prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
with torch.inference_mode():
for _ in range(200):
prepared(torch.randn(8, 3, 224, 224))
model_int8 = convert_fx(prepared)
FX로 바꿨는데 정확도 급락이 완화되는 경우도 있고, 반대로 특정 패턴이 누락되어 악화되는 경우도 있습니다. 중요한 건 “어떤 op가 quantize 되었는지”를 출력으로 확인하면서, 민감 구간을 부분 제외하는 식으로 접근하는 것입니다.
디버깅 팁: 어디서부터 망가지는지 빠르게 찾기
정확도 급락을 한 번에 해결하려고 하면 시간이 오래 걸립니다. 아래 순서가 효율적입니다.
- FP32 vs int8 출력 차이를 레이어별로 비교
- activation outlier 여부 확인(관측기 통계)
- fusion 적용 여부 확인
- 민감 레이어 제외 후 정확도 회복되는지 확인
간단한 형태로는 forward hook으로 중간 텐서를 비교할 수 있습니다.
import torch
def capture_activations(model, layer_names):
acts = {}
hooks = []
name_to_module = dict(model.named_modules())
for ln in layer_names:
m = name_to_module[ln]
def _hook(name):
def fn(mod, inp, out):
# out이 tuple일 수도 있으니 필요 시 분기
acts[name] = out.detach().cpu()
return fn
hooks.append(m.register_forward_hook(_hook(ln)))
return acts, hooks
# 사용 예
# acts_fp32, h1 = capture_activations(model_fp32, ["fc1", "fc2"])
# acts_int8, h2 = capture_activations(model_int8, ["fc1", "fc2"])
레이어별 MSE나 cosine similarity를 보면 “어느 구간부터 오차가 폭발하는지”가 보이고, 그 구간을 중심으로 observer 변경 또는 부분 제외를 적용하면 됩니다.
운영 관점: 재현성과 배포 안정화
PTQ는 데이터 통계에 의존하므로, 운영에서는 다음을 권장합니다.
- 캘리브레이션 샘플을 버전 관리(해시, 날짜, 샘플링 규칙)
- 양자화 설정(qconfig, 제외 레이어 목록, 엔진)을 코드로 고정
- FP32와 int8의 회귀 테스트를 CI에서 자동화
모델 서빙을 카나리로 굴리며 안전하게 전환하는 접근은 GPU/CPU 환경을 막론하고 중요합니다. 배포 파이프라인 관점은 KServe+Istio로 GPU 모델 카나리 배포 실전 가이드도 함께 참고하면, “정확도/지연시간/오류율”을 단계적으로 검증하는 그림을 잡기 좋습니다.
또한 대규모 실험 로그나 캘리브레이션 샘플을 CSV로 쌓다 보면 메모리 이슈가 쉽게 나는데, 이때는 Python Polars로 100GB CSV 메모리 오류 해결 같은 방식으로 분석 파이프라인을 안정화할 수 있습니다.
실전 처방전: 급락했을 때 가장 먼저 해볼 조합
정확도 급락 상황에서 “가장 성공 확률이 높은” 처방전 순서입니다.
- 캘리브레이션 데이터 늘리기 + 입력 전처리 완전 일치
- weight per-channel symmetric 적용
- activation을 HistogramObserver로 변경
- Conv-BN-ReLU 등 fusion 적용 여부 확인
- 민감 레이어(첫/마지막, norm/softmax 주변) 부분 제외
- Eager에서 막히면 FX Graph Mode로 전환
이 순서대로 하면, “원인 미상 급락”을 대부분 구조적으로 해소할 수 있습니다.
마무리
PyTorch 2.x PTQ int8에서 정확도 급락은 흔하지만, 무작정 파라미터를 바꾸기보다 캘리브레이션 분포, 관측기 선택, fusion, per-channel weight, 부분 양자화를 순서대로 점검하면 재현성 있게 해결할 수 있습니다.
다음 단계로는:
- 모델 아키텍처별 권장 qconfig 템플릿 정리
- 레이어별 오차 기반 자동 제외(화이트리스트/블랙리스트)
- PTQ로 한계가 보이면 QAT로 전환
을 고려하면, 비용 절감과 품질을 동시에 달성하기가 훨씬 쉬워집니다.