- Published on
PyTorch 2.x PTQ로 INT8 양자화 정확도 지키기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 비용을 줄이기 위해 INT8 양자화를 붙였는데, 모델 정확도가 눈에 띄게 떨어져 롤백한 경험이 있다면 PTQ 설정과 캘리브레이션 과정이 원인일 가능성이 큽니다. PyTorch 2.x는 torch.ao.quantization(AOQuant) 기반으로 PTQ 파이프라인이 정리되어 있고, 백엔드(x86 fbgemm, ARM qnnpack)에 맞는 qconfig와 관측기(observer) 선택, 캘리브레이션 데이터 품질, 레이어별 예외 처리만 제대로 해도 정확도 하락을 상당히 줄일 수 있습니다.
이 글은 PyTorch 2.x에서 Eager Mode PTQ를 기준으로, INT8 양자화 시 흔히 터지는 정확도 이슈를 “왜”와 “어떻게” 관점에서 정리합니다.
PTQ에서 정확도가 떨어지는 대표 원인
1) 캘리브레이션 데이터가 실제 분포를 못 따라감
PTQ는 학습 없이 관측기(Observer)가 activation 통계를 수집해 스케일과 제로포인트를 정합니다. 이때 캘리브레이션 데이터가 실제 트래픽과 분포가 다르면, activation 범위가 과소 또는 과대 추정되어 양자화 오차가 커집니다.
실전 팁:
- 최소 수백~수천 샘플을 권장(모델/도메인에 따라 다름)
- 전처리(정규화, 리사이즈, 토크나이즈)가 서빙과 100% 동일해야 함
- “쉬운 샘플”만 모으면 분산이 줄어들어 오히려 더 나빠질 수 있음
2) 관측기/양자화 스킴 선택이 모델 특성과 불일치
대표적으로 activation을 per_tensor_affine로 잡아버리면 채널별 스케일 차이가 큰 모델에서 오차가 커집니다. 반면 weight는 대개 per_channel이 유리합니다.
또한 스킴에 따라 대칭(symmetric) 또는 비대칭(asymmetric) 양자화가 결정되는데, activation은 비대칭이 일반적으로 유리하고 weight는 대칭이 유리한 경우가 많습니다.
3) 레이어 퓨전(fuse) 누락
Conv-BN-ReLU 같은 패턴을 fuse하지 않으면 BN이 별도 op로 남아 activation 분포가 달라지고 quant/dequant 지점이 늘어나 성능과 정확도 모두 손해를 봅니다.
4) 민감 레이어까지 무리하게 INT8로 내림
첫 번째 Conv, 마지막 FC(또는 classifier head), attention 관련 연산 일부는 INT8에서 민감하게 흔들릴 수 있습니다. PTQ는 “전체 일괄 INT8”보다 “부분 제외”가 정확도 방어에 효과적일 때가 많습니다.
5) 백엔드/하드웨어 불일치
x86 서버에서 fbgemm로 캘리브레이션하고 ARM 환경에서 qnnpack로 돌리면 커널 특성 차이로 수치가 달라질 수 있습니다. 가능하면 목표 환경과 같은 백엔드로 검증하세요.
PyTorch 2.x PTQ 기본 파이프라인
PTQ는 크게 다음 순서입니다.
- 모델을
eval()로 전환 - fuse 가능한 모듈 fuse
qconfig설정prepare로 관측기 삽입- 캘리브레이션 데이터로 추론 돌려 통계 수집
convert로 실제 INT8 연산 모듈로 변환- 정확도/성능 측정
아래 예시는 이미지 분류 같은 CNN을 가정한 템플릿입니다.
import copy
import torch
import torch.nn as nn
import torch.ao.quantization as aq
# 예시용 더미 모델 (실전에서는 torchvision 모델 등 사용)
class SmallCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, 3, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = self.features(x)
x = self.pool(x).flatten(1)
return self.fc(x)
def fuse_model(model: nn.Module) -> nn.Module:
# Conv-BN-ReLU 패턴 fuse
aq.fuse_modules(model.features, [['0', '1', '2'], ['3', '4', '5']], inplace=True)
return model
def ptq_int8(model_fp32, calib_loader, backend='fbgemm'):
torch.backends.quantized.engine = backend
model = copy.deepcopy(model_fp32).eval()
model = fuse_model(model)
# 기본 qconfig (백엔드에 맞는 관측기/스킴 세트)
model.qconfig = aq.get_default_qconfig(backend)
prepared = aq.prepare(model, inplace=False)
# 캘리브레이션: 통계 수집
with torch.inference_mode():
for images, _ in calib_loader:
prepared(images)
quantized = aq.convert(prepared, inplace=False)
return quantized
이 템플릿은 “동작은 하는” 수준입니다. 정확도를 지키려면 다음 섹션의 방어 전략을 적용해야 합니다.
정확도 하락을 막는 핵심 전략 7가지
1) activation 관측기: HistogramObserver로 범위 추정 안정화
기본 관측기는 min/max 기반이라 outlier에 취약할 수 있습니다. outlier가 많은 데이터(예: 자연어 logits, 일부 detection feature map)에서는 히스토그램 기반이 더 안정적인 경우가 많습니다.
import torch.ao.quantization as aq
act_observer = aq.HistogramObserver.with_args(
reduce_range=False
)
qconfig = aq.QConfig(
activation=act_observer,
weight=aq.default_per_channel_weight_observer
)
model.qconfig = qconfig
포인트:
- weight는
default_per_channel_weight_observer로 채널별 스케일을 쓰는 편이 일반적으로 유리 - activation은 히스토그램 기반을 시도해보고, 정확도/지연시간 트레이드오프를 확인
2) 캘리브레이션 샘플 수를 늘리고, 분포를 “서빙처럼” 만들기
정확도 하락이 크면 대부분 여기서 갈립니다.
실전 체크:
- 데이터 증강을 캘리브레이션에 넣지 말고(서빙과 다르므로), 서빙 전처리만 적용
- 길이가 다른 입력(문장 길이, 이미지 해상도 등)이 섞인다면 그 분포를 반영
- 클래스 불균형이 심하면, 캘리브레이션이 특정 클래스에 치우치지 않게 샘플링
3) 첫 레이어/마지막 레이어는 FP32로 남겨두기
INT8로 내리면 이득이 크지 않은데 정확도는 크게 흔들릴 수 있는 구간입니다. 모듈별로 qconfig=None로 제외할 수 있습니다.
def disable_quant_for_sensitive_layers(model: nn.Module):
# 예: 첫 Conv와 마지막 FC를 제외
model.features[0].qconfig = None
model.fc.qconfig = None
model = fuse_model(model).eval()
model.qconfig = aq.get_default_qconfig('fbgemm')
disable_quant_for_sensitive_layers(model)
prepared = aq.prepare(model, inplace=False)
실전에서는 “정확도 민감 레이어”를 찾기 위해 레이어별 ablation을 합니다.
- head만 FP32로 남겨보기
- stem(초기 feature extractor)만 FP32로 남겨보기
- 특정 block만 FP32로 남겨보기
4) 연산자 단위로 quant/dequant 경계를 줄이기
양자화 경계가 많으면 수치 오차가 누적됩니다. fuse를 최대한 하고, 가능하면 quantizable한 블록 구조를 유지하세요.
예:
- Conv-BN-ReLU는 fuse
- residual add가 있는 경우, 양자화 친화적인 패턴으로 유지(가능하면 PyTorch quantization이 지원하는 형태)
5) 백엔드에 맞는 설정을 고정하고, 재현 가능하게 측정하기
- x86 서버:
fbgemm - ARM(모바일/엣지):
qnnpack
또한 성능/정확도 측정 시 다음을 고정하세요.
model.eval()torch.inference_mode()- 스레드 수(
torch.set_num_threads) - 입력 배치/shape
6) 정확도만 보지 말고 “출력 분포”를 비교해 문제 지점을 찾기
Top-1만 보면 어디서 틀어졌는지 감이 안 옵니다. FP32 vs INT8의 logits 분포 차이를 보면 outlier나 saturation을 빠르게 의심할 수 있습니다.
import torch
import torch.nn.functional as F
def compare_logits(model_fp32, model_int8, x):
model_fp32.eval(); model_int8.eval()
with torch.inference_mode():
y0 = model_fp32(x)
y1 = model_int8(x)
p0 = F.softmax(y0, dim=-1)
p1 = F.softmax(y1, dim=-1)
kl = (p0 * (p0.clamp_min(1e-9).log() - p1.clamp_min(1e-9).log())).sum(dim=-1).mean()
max_abs = (y0 - y1).abs().max()
return float(kl), float(max_abs)
- KL divergence가 특정 입력에서만 튀면, 그 입력의 activation 범위를 관측기가 제대로 못 잡았을 수 있습니다.
7) PTQ로 안 되면 “최소 비용 QAT”를 고려
PTQ로 정확도 하락을 1~2% 이내로 못 막는 모델이 있습니다(특히 transformer 계열). 이때는 짧은 QAT(Quantization Aware Training)로 보정하는 편이 총비용이 더 낮을 수 있습니다.
PTQ로 끝내야 한다면:
- 민감 레이어 제외
- 더 나은 관측기
- 더 좋은 캘리브레이션 이 3가지를 먼저 끝까지 해보는 것이 순서입니다.
실전 예제: ResNet 계열에서 흔한 설정
아래는 “기본 PTQ + fuse + 일부 레이어 제외 + 커스텀 observer”를 합친 예시입니다.
import copy
import torch
import torch.ao.quantization as aq
import torchvision
def fuse_resnet(model):
# torchvision resnet은 fuse 가능한 유틸이 따로 없으니,
# 일반적으로는 torch.ao.quantization.fuse_modules를 블록 단위로 적용합니다.
# 여기서는 예시로 첫 conv/bn/relu만 fuse
aq.fuse_modules(model, [['conv1', 'bn1', 'relu']], inplace=True)
return model
def build_qconfig():
act = aq.HistogramObserver.with_args(reduce_range=False)
wt = aq.default_per_channel_weight_observer
return aq.QConfig(activation=act, weight=wt)
def ptq_resnet18(calib_loader, backend='fbgemm'):
torch.backends.quantized.engine = backend
fp32 = torchvision.models.resnet18(weights=None).eval()
model = copy.deepcopy(fp32)
fuse_resnet(model)
model.qconfig = build_qconfig()
# 민감할 수 있는 head 제외(상황에 따라 conv1도 제외 고려)
model.fc.qconfig = None
prepared = aq.prepare(model, inplace=False)
with torch.inference_mode():
for images, _ in calib_loader:
prepared(images)
int8_model = aq.convert(prepared, inplace=False)
return fp32, int8_model
주의:
- ResNet의 모든 블록을 제대로 fuse하려면 BasicBlock 내부의 conv/bn/relu를 블록 단위로 fuse해야 합니다. 모델 구조에 따라 fuse 포인트가 다르므로, 본인 모델의 모듈 이름을 찍어서 정확히 지정하세요.
디버깅 체크리스트
캘리브레이션이 제대로 됐는지
- 캘리브레이션을 10개로 했을 때와 1,000개로 했을 때 정확도 차이가 큰가
- 특정 입력군(야간 이미지, 긴 문장, 특정 도메인)에서만 급락하는가
레이어별로 어디가 민감한지
- head를 FP32로 두면 회복되는가
- stem을 FP32로 두면 회복되는가
- 특정 block만 FP32로 두면 회복되는가
배포 환경 이슈(성능/안정성)
INT8로 바꾸는 이유가 대개 “지연시간/비용”인데, 배포 환경에서 병목이 다른 곳에 있으면 기대만큼 이득이 안 날 수 있습니다. 예를 들어 노드 메모리 압박으로 OOM이 나면 최적화 이전에 안정성부터 잡아야 합니다. 이런 유형의 장애 패턴은 Kubernetes CrashLoopBackOff와 OOMKilled(ExitCode 137) 해결 같은 운영 관점 글과 함께 점검하면 좋습니다.
마무리: “관측기 + 캘리브레이션 + 예외처리”가 80%를 해결
PyTorch 2.x PTQ에서 INT8 정확도 하락을 막는 가장 현실적인 접근은 다음 3단계입니다.
- fuse를 최대한 적용해 불필요한 quant/dequant 경계를 줄이기
- 캘리브레이션 데이터를 서빙 분포에 맞추고 충분히 확보하기
- activation 관측기를 히스토그램 기반으로 바꾸고, 민감 레이어는 FP32로 제외하기
이 조합으로도 목표 정확도를 못 맞추면, PTQ의 한계일 수 있으니 짧은 QAT 또는 혼합 정밀도 전략(FP16 또는 일부 INT8)을 검토하는 것이 시간 대비 효율적입니다.