- Published on
PyTorch 2.x PTQ INT8 양자화 정확도 하락 복구법
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버 추론 비용을 줄이려고 INT8 양자화를 적용했는데, 기대했던 지연시간 개선은 얻었지만 정확도가 눈에 띄게 떨어지는 경우가 많습니다. 특히 PyTorch 2.x 계열에서는 torch.compile/그래프 변환, torch.ao.quantization(PTQ), 백엔드(x86 FBGEMM, ARM QNNPACK/XNNPACK) 조합에 따라 “양자화는 됐는데 왜 이렇게 망가지지?” 같은 상황이 자주 발생합니다.
이 글은 PyTorch 2.x에서 PTQ(Post-Training Quantization) 로 INT8을 적용했을 때 정확도 하락을 복구하는 방법을, 원인별로 분해해서 설명합니다. 결론부터 말하면 대부분의 정확도 이슈는 아래 4가지로 귀결됩니다.
- 캘리브레이션 데이터/루프가 부정확하거나 부족함
- 관측자(observer) 설정이 모델/데이터 분포에 맞지 않음
- 연산자 조합(Conv/Linear, activation, residual, concat 등)에서 per-tensor 양자화가 병목이 됨
- 폴딩/퓨전/패딩/정규화 등 전처리와 그래프 변환이 “학습 시 의미”와 달라짐
운영 환경에서 재현하기 어려운 문제는 로그/헬스체크가 핵심인데, 배포 파이프라인에서 모델 로딩이 반복 재시작되거나 프로브에 걸려 장애로 보일 수도 있습니다. 그런 경우에는 별도로 K8s CrashLoopBackOff 원인별 로그·Probe 해결 가이드 같은 글의 체크리스트도 함께 참고하면 원인 분리가 빨라집니다.
PyTorch 2.x PTQ INT8의 기본 흐름 정리
PyTorch PTQ는 크게 다음 순서입니다.
- FP32 모델 준비(가급적
eval()고정) - quantization config 지정(
qconfig또는qconfig_mapping) prepare단계에서 observer 삽입- 캘리브레이션 데이터로 forward 수행(통계 수집)
convert단계에서 실제INT8모듈로 치환
PyTorch 2.x에서는 FX Graph Mode(권장)가 일반적입니다.
torch.ao.quantization.quantize_fx.prepare_fxtorch.ao.quantization.quantize_fx.convert_fx
정확도 복구는 대부분 prepare_fx 이전의 모델 구조/전처리 정리, qconfig/observer 튜닝, 캘리브레이션 루프 정교화로 해결됩니다.
가장 흔한 정확도 급락 원인 1: 캘리브레이션이 “진짜 입력 분포”를 못 담음
PTQ는 학습 없이 통계만 보고 스케일/제로포인트를 잡습니다. 즉 캘리브레이션 데이터가 실제 트래픽 분포와 다르면, activation 범위가 틀어져서 saturation/클리핑이 발생하고 정확도가 급락합니다.
체크리스트
- 캘리브레이션 샘플 수가 너무 적지 않은가
- 분류 기준으로는 최소 수백~수천 장, NLP는 수천 문장 이상을 권장하는 경우가 많습니다.
- 전처리 파이프라인이 학습/추론과 동일한가
- resize, crop, normalize, tokenization, padding 전략이 다르면 통계가 달라집니다.
model.eval()인가- dropout, batchnorm이 학습 모드면 통계가 흔들립니다.
- 캘리브레이션 시
torch.no_grad()인가- 불필요한 그래프 구성은 성능뿐 아니라 실수 가능성을 높입니다.
캘리브레이션 루프 예시
아래 예시는 이미지 분류 모델을 가정합니다.
import torch
from torch.utils.data import DataLoader
def calibrate(model, dataloader, device="cpu", max_batches=None):
model.eval()
model.to(device)
with torch.no_grad():
for i, (x, _) in enumerate(dataloader):
if max_batches is not None and i >= max_batches:
break
x = x.to(device)
_ = model(x)
포인트는 prepare_fx 이후에 위 루프를 돌려 observer 통계를 충분히 쌓는 것입니다.
가장 흔한 정확도 급락 원인 2: MinMax observer가 outlier에 취약
기본 설정(예: MinMax 기반)은 outlier에 매우 취약합니다. activation에 드물게 큰 값이 섞이면 범위가 과도하게 커지고, 대부분의 값이 INT8의 좁은 구간에 몰려 양자화 오차가 커집니다.
이때 정확도 복구에 가장 효과적인 전략이 Histogram observer 또는 Percentile/MovingAverage 기반을 선택하는 것입니다.
관측자 선택 가이드
- outlier가 많은 분포(특히 ReLU 이후 activation, attention score 등)
- Histogram 기반을 우선 고려
- 분포가 안정적이고 outlier가 적음
- MovingAverageMinMax도 좋은 타협
FX Graph Mode에서 qconfig 예시
아래는 x86 서버에서 흔히 쓰는 fbgemm 백엔드를 가정합니다.
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
backend = "fbgemm"
qconfig_mapping = get_default_qconfig_mapping(backend)
# 예시 입력(Shape는 실제 모델 입력과 동일해야 함)
example_inputs = (torch.randn(1, 3, 224, 224),)
model_fp32 = model_fp32.eval()
prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
# 캘리브레이션
calibrate(prepared, calib_loader, device="cpu", max_batches=200)
model_int8 = convert_fx(prepared)
여기서 정확도가 크게 떨어진다면, get_default_qconfig_mapping만 믿지 말고 observer를 커스터마이즈해야 합니다.
정확도 복구 핵심 1: per-channel weight 양자화 강제
Conv/Linear weight는 per-tensor로 양자화하면 채널별 스케일 차이를 못 따라가서 손실이 커집니다. 대부분의 백엔드에서 per-channel weight quant가 정확도에 유리합니다.
- Conv2d/Linear weight는 per-channel
- activation은 per-tensor(대부분의 런타임 제약)
PyTorch 기본 qconfig가 이를 포함하는 경우가 많지만, 모델 구조/백엔드에 따라 누락되거나 특정 연산이 fallback 되기도 합니다.
연산자별 양자화 적용 범위 확인
변환 후에 특정 레이어가 float로 남거나, 예상치 못한 dequant-quant가 삽입되면 정확도와 성능이 동시에 나빠집니다. 변환된 그래프를 출력해 확인합니다.
print(model_int8)
또는 FX graph를 확인할 수 있으면(환경에 따라) quant/dequant 위치를 점검하세요.
정확도 복구 핵심 2: “문제 레이어만” FP16/FP32로 남기는 하이브리드 전략
PTQ에서 모든 연산을 INT8로 밀어붙이면, 특정 블록(예: 첫 Conv, 마지막 FC, LayerNorm 주변, residual merge)이 민감하게 반응해 정확도가 크게 흔들릴 수 있습니다.
이때는 선택적 양자화(exclude/override) 로 민감 레이어를 float로 남기는 것이 실전에서 매우 자주 쓰는 복구법입니다.
- 첫 레이어: 입력 분포가 다양해 activation 양자화 손실이 큼
- 마지막 레이어: logit margin이 줄어 top-1이 흔들림
- residual add 지점: scale mismatch로 오차가 누적
FX Graph Mode에서는 모듈 이름 패턴으로 qconfig를 끄는 방식이 가능합니다.
from torch.ao.quantization import QConfigMapping
qconfig_mapping = (
QConfigMapping()
.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
# 예: classifier는 float로 유지
.set_module_name("classifier", None)
)
모듈 이름은 print(model_fp32)로 확인해 실제 경로에 맞춰야 합니다. 이 방식은 “완전한 INT8”은 아니지만, 정확도 복구 대비 성능 손실이 작아 현실적인 타협점이 됩니다.
정확도 복구 핵심 3: activation quantization을 줄이는 패턴(특히 residual)
Residual 구조에서 add 직전 텐서들이 서로 다른 스케일을 갖는데, INT8에서 이를 맞추는 과정에서 오차가 커질 수 있습니다. 프레임워크가 내부적으로 rescale을 넣지만, 모델에 따라 민감도가 큽니다.
실전 팁:
- residual branch 중 하나에만 quant/dequant가 과도하게 들어가면 정확도 손실이 커집니다.
- concat 이후 채널 수가 커지는 지점은 activation 분포가 바뀌어 outlier가 늘 수 있습니다.
이런 지점은 앞서 말한 “부분 float 유지”로 우회하는 게 빠릅니다.
정확도 복구 핵심 4: BatchNorm folding과 fusion 확인
Conv + BN + ReLU 같은 패턴은 양자화 전에 fusion/folding이 일어나야 정확도와 성능이 좋아집니다. 그런데 모델이 커스텀 모듈로 감싸져 있거나, 그래프 변환이 예상대로 안 되면 fusion이 깨져서
- 불필요한 quant/dequant 삽입
- BN이 float로 남아 연산 경계가 증가
같은 문제가 생깁니다.
점검 방법
- 변환 후 모델에서 fused 모듈로 바뀌었는지 확인
- 기대한 연산자들이 quantized kernel로 매핑되는지 확인
환경 의존성이 강해서, 배포 환경에서만 성능/정확도가 흔들린다면 네트워크 타임아웃이나 노드 이슈처럼 보일 수 있습니다. 클러스터에서 추론 서버가 간헐적으로 느려지는 식이라면 EKS TLS handshake timeout 원인·해결 9가지처럼 “증상은 네트워크인데 원인은 CPU 스로틀/리소스”인 케이스도 같이 의심해보는 게 좋습니다.
실제 복구 절차: 정확도 하락을 단계적으로 줄이는 플로우
아래 순서대로 적용하면 원인을 빠르게 좁힐 수 있습니다.
1) FP32 vs INT8 정확도 격차를 먼저 수치화
- top-1/top-5, F1, BLEU 등 대표 지표를 동일 evaluation 코드로 측정
- 샘플링이 아니라 전체 검증셋으로 1회는 꼭 측정
2) 캘리브레이션 데이터와 루프를 먼저 고친다
- 샘플 수 증가
- 실제 트래픽 분포 반영
- 전처리 동일성 보장
3) observer를 MinMax에서 Histogram 계열로 변경
- outlier로 인한 scale 붕괴를 완화
- 특히 activation에서 효과가 큼
4) per-channel weight quant가 적용되는지 확인
- Conv/Linear 중심
- 적용이 안 되면 qconfig를 명시적으로 지정
5) 민감 레이어를 float로 제외(하이브리드)
- 첫/마지막 레이어부터 시도
- residual merge 주변 블록을 부분적으로 제외
6) 그래프 상 quant/dequant 경계가 과도한지 확인
- 경계가 많을수록 오차 누적과 성능 저하 가능성 증가
재현 가능한 예제: FX PTQ 파이프라인 스켈레톤
아래 코드는 “준비-캘리브레이션-변환-평가”의 최소 골격입니다. 모델/데이터셋만 교체하면 바로 실험 루프를 만들 수 있습니다.
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
def evaluate_top1(model, dataloader, device="cpu"):
model.eval()
model.to(device)
correct = 0
total = 0
with torch.no_grad():
for x, y in dataloader:
x = x.to(device)
y = y.to(device)
logits = model(x)
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.numel()
return correct / max(total, 1)
backend = "fbgemm"
qconfig_mapping = get_default_qconfig_mapping(backend)
example_inputs = (torch.randn(1, 3, 224, 224),)
model_fp32 = model_fp32.eval()
acc_fp32 = evaluate_top1(model_fp32, val_loader)
prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs)
calibrate(prepared, calib_loader, max_batches=200)
model_int8 = convert_fx(prepared)
acc_int8 = evaluate_top1(model_int8, val_loader)
print({"fp32": acc_fp32, "int8": acc_int8, "drop": acc_fp32 - acc_int8})
이 상태에서 drop이 크면, 앞 절의 복구 전략을 하나씩 적용하면서 “어떤 조치가 drop을 줄였는지”를 실험 로그로 남기면 됩니다.
운영 관점 팁: 정확도 이슈를 성능/장애 이슈와 분리하기
양자화 후 정확도 하락은 모델 품질 문제지만, 현업에서는 종종 다음과 섞여 보입니다.
- 일부 요청만 느려짐(quant/dequant 경계 증가, fallback kernel)
- 특정 노드에서만 오동작(백엔드/CPU feature 차이)
- 재시작 루프/헬스체크 실패(모델 로딩 지연)
이럴 때는 애플리케이션 레벨의 원인과 인프라 레벨의 증상을 분리해 접근해야 합니다. 서비스 프로세스가 재시작을 반복한다면 systemd 재시작 루프 - ExecStart 디버깅 가이드처럼 “실패 지점의 로그를 먼저 고정”하는 방식이 도움이 됩니다.
마무리: PTQ INT8 정확도 복구의 우선순위
PyTorch 2.x에서 PTQ로 INT8을 적용했을 때 정확도 하락을 복구하는 가장 효율적인 우선순위는 다음입니다.
- 캘리브레이션 데이터/전처리/루프를 실제 추론과 동일하게 맞춘다
- outlier에 강한 observer로 바꾼다(특히 activation)
- per-channel weight quant 적용 여부를 확인한다
- 민감 레이어는 과감히 float로 남겨 하이브리드로 타협한다
- 변환 후 그래프에서 quant/dequant 경계를 줄이고 fallback을 제거한다
PTQ는 “학습 없이” 성능을 얻는 대신, 통계 품질과 그래프 변환 품질에 매우 민감합니다. 위 체크리스트를 순서대로 적용하면, 많은 경우 INT8의 이점을 유지하면서도 정확도 손실을 실무 허용 범위로 되돌릴 수 있습니다.