- Published on
PyTorch PTQ로 INT8 변환해 정확도 지키기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 지연 시간과 메모리 사용량을 줄이기 위해 INT8 양자화는 가장 먼저 검토되는 최적화 옵션입니다. 특히 PTQ(Post-Training Quantization)는 재학습 없이 적용할 수 있어 도입 장벽이 낮지만, 설정을 대충 하면 정확도가 눈에 띄게 떨어지기도 합니다. 이 글에서는 PyTorch에서 PTQ로 INT8 변환을 수행할 때 정확도를 유지하기 위해 반드시 챙겨야 할 체크리스트와 코드 패턴을 정리합니다.
참고로 4bit 같은 더 공격적인 양자화를 시도하다가 실패하는 케이스도 많은데, 원인과 대응은 별도 글인 PyTorch 모델 4bit 양자화 실패 7가지와 해법에서도 함께 확인할 수 있습니다.
PTQ INT8의 핵심: 무엇이 정확도를 깎는가
PTQ는 학습된 가중치와 활성값을 낮은 비트로 근사합니다. 이때 정확도 하락은 대개 아래에서 발생합니다.
- 캘리브레이션 데이터 분포 불일치: 관측자가 본 활성값 범위가 실제 트래픽과 다르면 스케일이 틀어집니다.
- 관측자/스케일링 방식 선택 오류:
minmax기반은 outlier에 취약하고,histogram기반은 계산 비용이 더 들지만 안정적인 경우가 많습니다. - 레이어별 민감도 차이 무시: 첫 Conv, 마지막 Linear, attention의 projection 계열은 특히 민감할 수 있습니다.
- 연산자 폴딩 및 그래프 변환 문제: Conv-BN-ReLU 패턴의 폴딩 여부, fusion 타이밍이 결과에 영향을 줍니다.
- 백엔드(qengine) 불일치: 서버 CPU에서는 보통
fbgemm, ARM 계열은qnnpack가 성능과 정확도에 영향을 줍니다.
이 글의 목표는 “일단 INT8로 바꾸기”가 아니라 “정확도 하락을 측정하고, 원인을 좁혀가며, 최소한의 희생으로 성능을 얻는 방법”입니다.
어떤 양자화 API를 쓸 것인가: FX Graph Mode 권장
PyTorch에는 크게 2가지 흐름이 있습니다.
- Eager mode quantization: 모듈에 직접
QuantStub등을 붙이거나 수동 설정이 많음 - FX Graph Mode quantization: 모델 그래프를 추적해 패턴 매칭과 변환을 자동화
최근 실무에서는 FX 기반이 재현성과 유지보수 측면에서 유리합니다. 아래 예제도 FX Graph Mode를 기준으로 설명합니다.
준비: 백엔드와 평가 루프 고정
정확도 비교는 조건이 조금만 달라도 흔들립니다. 아래는 최소한의 고정 사항입니다.
model.eval()torch.no_grad()- 동일한 전처리/후처리
- 동일한 평가 데이터셋
- CPU에서 INT8을 돌릴 계획이면 평가도 CPU에서 수행
또한 qengine을 명시합니다.
import torch
# 서버 x86 CPU에서 일반적으로 사용
torch.backends.quantized.engine = "fbgemm"
# ARM/모바일 계열이면 보통 qnnpack
# torch.backends.quantized.engine = "qnnpack"
캘리브레이션 데이터: “몇 개”가 아니라 “어떤 분포”가 중요
PTQ에서 캘리브레이션은 관측자(observer)가 활성값 통계를 수집하는 과정입니다. 여기서 가장 흔한 실수는 “학습 데이터 일부를 대충 넣기”입니다.
정확도 유지를 위해서는 아래 원칙을 권장합니다.
- 실제 서빙 입력 분포를 대표하는 샘플을 포함
- 길이/해상도/밝기/노이즈 등 변동이 큰 축을 커버
- 분류/검출/LLM 등 태스크에 따라 outlier가 자주 등장하는 패턴을 포함
대략적인 샘플 수 가이드는 모델과 태스크에 따라 다르지만, 경험적으로는 다음이 출발점이 됩니다.
- 이미지 분류: 200~1,000 샘플
- 검출/세그: 500~2,000 샘플
- NLP(Transformer): 문장 길이 분포를 맞춘 500~5,000 샘플
중요한 점은 “샘플 수를 늘리면 무조건 좋아진다”가 아니라, 대표성 없는 데이터 1만 개보다 대표성 있는 500개가 낫다는 것입니다.
실전 코드: FX PTQ로 INT8 변환
아래 예제는 PyTorch의 FX Graph Mode로 PTQ를 적용하는 기본 골격입니다. 모델에 따라 qconfig_mapping과 example_inputs가 달라질 수 있습니다.
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
def ptq_int8_fx(model, calib_loader, example_inputs):
model = model.eval()
# 백엔드에 맞는 기본 qconfig
qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine)
# 1) 준비: 관측자 삽입
prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)
# 2) 캘리브레이션: 관측자가 통계 수집
with torch.no_grad():
for batch in calib_loader:
# 배치 형태는 프로젝트에 맞게 조정
if isinstance(batch, (list, tuple)):
x = batch[0]
else:
x = batch
prepared(x)
# 3) 변환: INT8 커널로 교체
quantized = convert_fx(prepared)
return quantized
example_inputs를 왜 넣어야 하나
FX는 그래프를 추적할 때 입력 텐서의 형태와 dtype을 참고합니다. example_inputs는 단순히 더미가 아니라, 변환 가능한 연산 패턴을 제대로 잡기 위한 “스펙”입니다.
예를 들어 이미지 모델이라면 다음처럼 지정합니다.
example_inputs = (torch.randn(1, 3, 224, 224),)
Transformer 계열은 input ids, attention mask 등 복수 입력이 되므로 모델 forward 시그니처에 맞춰야 합니다.
정확도 유지의 핵심 1: 관측자(Observer)와 스케일링 전략
기본 qconfig는 빠르게 시작하기에는 좋지만, 정확도 문제가 생기면 관측자 설정부터 의심해야 합니다.
대표적으로 다음 선택지가 있습니다.
- Per-tensor vs Per-channel: 가중치는 per-channel이 정확도에 유리한 경우가 많습니다.
- MinMaxObserver vs HistogramObserver: outlier가 있으면 histogram이 안정적일 수 있습니다.
- 활성값 symmetric vs asymmetric: ReLU 이후는 비대칭이 유리한 경우가 많고, 특정 백엔드는 대칭을 선호하기도 합니다.
아래는 커스텀 qconfig를 구성하는 예시입니다. 코드에서 부등호 문자가 본문에 노출되지 않도록 모두 코드 블록 안에 넣었습니다.
import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
act_observer = HistogramObserver.with_args(
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
)
wt_observer = PerChannelMinMaxObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
)
custom_qconfig = QConfig(activation=act_observer, weight=wt_observer)
이후 qconfig_mapping에 반영합니다.
from torch.ao.quantization import QConfigMapping
qconfig_mapping = (
QConfigMapping()
.set_global(custom_qconfig)
)
정확도 유지의 핵심 2: 레이어별 예외 처리(부분 양자화)
모든 레이어를 INT8로 바꾸는 것이 항상 최선은 아닙니다. 특히 아래는 FP32로 남겨두는 것이 정확도에 도움이 되는 경우가 많습니다.
- 입력단 첫 Conv
- 출력단 마지막 Linear
- 특정 attention block의 projection
- LayerNorm, Softmax 등 (대개 양자화 이득이 작거나 변환이 까다로움)
FX에서는 모듈 이름 패턴으로 제외할 수 있습니다. 예시는 다음과 같습니다.
from torch.ao.quantization import float_qparams_weight_only_qconfig
qconfig_mapping = (
QConfigMapping()
.set_global(custom_qconfig)
# 예: 마지막 분류기 레이어는 제외
.set_module_name("classifier", None)
)
모듈 이름을 정확히 지정하려면 다음처럼 모델의 named_modules를 먼저 확인합니다.
for name, m in model.named_modules():
if "classifier" in name or "head" in name:
print(name, type(m))
부분 양자화는 “정확도는 지키면서도 대부분의 연산을 INT8로 가져가는” 현실적인 절충안입니다.
정확도 유지의 핵심 3: 캘리브레이션을 평가처럼 다뤄라
캘리브레이션은 단순 forward가 아니라 “통계 수집 실험”입니다. 다음을 권장합니다.
- 캘리브레이션 중 입력 범위 로그(최소/최대, 퍼센타일)를 수집
- 대표 샘플을 바꿔가며 결과 변동 확인
- outlier가 많은 입력이 있는지 점검
간단히 분포를 보는 방법은 관측자 모듈을 훑어보는 것입니다. 변환 전 prepared 상태에서 observer를 찾아 통계를 확인할 수 있습니다.
import torch
for name, m in prepared.named_modules():
if "activation_post_process" in name:
# observer 내부 상태는 타입에 따라 다르지만
# min_val/max_val을 제공하는 경우가 많습니다.
if hasattr(m, "min_val") and hasattr(m, "max_val"):
print(name, m.min_val, m.max_val)
이 과정에서 특정 레이어만 범위가 비정상적으로 크다면, 그 레이어를 FP32로 남기거나(부분 양자화), activation observer를 histogram으로 바꾸는 식으로 대응합니다.
정확도 유지의 핵심 4: 측정 지표를 “최종 정확도” 하나로만 두지 않기
INT8 변환 후 정확도가 떨어졌을 때 원인을 좁히려면, 최종 top-1 같은 지표만 보면 시간이 오래 걸립니다. 아래 보조 지표를 함께 두면 진단이 빨라집니다.
- 레이어별 출력 MSE 또는 cosine similarity
- 대표 샘플에 대한 로짓 분포 변화
- 클래스별 precision/recall 변화
간단한 로짓 cosine 비교 예시는 다음과 같습니다.
import torch
import torch.nn.functional as F
@torch.no_grad()
def compare_logits(fp32_model, int8_model, x):
fp32_model.eval()
int8_model.eval()
a = fp32_model(x)
b = int8_model(x)
a = a.flatten(1)
b = b.flatten(1)
cos = F.cosine_similarity(a, b, dim=1).mean().item()
mse = F.mse_loss(a, b).item()
return {"cosine": cos, "mse": mse}
cosine이 특정 입력에서 급락한다면, 그 입력이 캘리브레이션에 포함되지 않았거나 outlier에 취약한 설정일 가능성이 큽니다.
배포 관점 팁: INT8은 CPU에서 특히 효과적이다
GPU에서는 FP16, BF16이 더 흔한 최적화 경로이고, INT8은 엔진과 커널 지원에 따라 편차가 큽니다. 반면 CPU 서빙에서는 INT8이 지연 시간과 비용에 직접적인 이득을 주는 경우가 많습니다.
서빙 비용 최적화는 모델만의 문제가 아니라 인프라와도 맞물립니다. 대규모 운영에서 비용을 줄이는 관점은 EKS 비용 40%↓ - Karpenter+Graviton 전환 실전도 함께 참고하면, “모델 최적화와 노드 최적화”를 같이 설계하는 데 도움이 됩니다.
또한 로컬에서 LLM을 돌리며 메모리 병목을 해결하는 접근은 INT8과는 결이 다르지만, 시스템적으로 비슷한 의사결정(정확도 vs 비용 vs 지연)을 요구합니다. 관련해서는 Transformers 로컬 LLM OOM - 4bit+오프로딩도 같이 보면 좋습니다.
자주 발생하는 실패 패턴과 처방
1) 정확도가 소폭이 아니라 크게 떨어진다
- 캘리브레이션 데이터가 너무 적거나 분포가 다름
- activation observer를
HistogramObserver로 변경 - 민감 레이어를 FP32로 제외
2) 특정 클래스에서만 성능이 무너진다
- 해당 클래스를 대표하는 입력이 캘리브레이션에 부족
- 입력 전처리(정규화, 리사이즈)가 서빙과 다름
3) 속도 이득이 생각보다 없다
- CPU에서 스레딩 설정, 배치 크기, 연산자 fusion 여부 확인
- 모델이 INT8으로 변환되지 않은 연산이 많을 수 있음
- 실제로는 memory bound라서 INT8 이득이 제한적일 수 있음
체크리스트: 정확도 유지용 PTQ 절차
torch.backends.quantized.engine를 타깃 환경에 맞게 고정- FX Graph Mode로
prepare_fx후 캘리브레이션 수행 - 캘리브레이션 데이터는 “대표성” 중심으로 구성
- 기본 qconfig에서 시작하되, 문제 시 observer를 histogram 기반으로 조정
- 민감 레이어는 부분 양자화로 FP32 유지
- 최종 정확도 외에 로짓 유사도, 레이어별 오차로 원인 좁히기
마무리
PTQ INT8은 “한 번에 완벽히” 되기보다, 캘리브레이션과 관측자 설정을 바꿔가며 정확도 하락 원인을 좁혀가는 반복 작업에 가깝습니다. 하지만 절차를 체계화하면 재학습 없이도 상당한 수준의 정확도를 유지하면서 CPU 서빙 비용과 지연 시간을 줄일 수 있습니다.
다음 단계로는 (1) 부분 양자화 정책을 모델별로 템플릿화하고, (2) 캘리브레이션 데이터 샘플링을 자동화하며, (3) 변환 전후를 CI에서 회귀 테스트로 묶는 것을 추천합니다.