- Published on
PyTorch PTQ/QAT로 INT8 양자화 - 정확도 유지
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 모델을 더 빠르고 가볍게 만들고 싶을 때, 가장 먼저 검토하는 카드가 INT8 양자화입니다. FP32 대비 메모리 사용량을 줄이고(가중치 기준 약 4분의 1), CPU 추론에서 특히 유의미한 속도 향상을 기대할 수 있습니다. 문제는 정확도입니다. 같은 INT8이라도 PTQ(Post-Training Quantization) 로 끝낼지, QAT(Quantization-Aware Training) 까지 갈지에 따라 정확도 손실과 개발 비용이 크게 달라집니다.
이 글은 PyTorch(특히 torch.ao.quantization)에서 PTQ/QAT를 적용해 정확도 하락을 최소화하는 방법을, 실수하기 쉬운 포인트(캘리브레이션, 관측자, fusion, per-channel, backend 설정) 중심으로 정리합니다.
또한 운영 관점에서 “성능 최적화는 결국 병목을 줄이는 일”이라는 점에서, 재시도 폭주/타임아웃 설계처럼 시스템 레벨에서의 성능 안정화도 함께 고려해야 합니다. 관련해서는 gRPC MSA에서 데드라인·리트라이 폭주 막는 법도 함께 참고하면 좋습니다.
INT8 양자화 기본: 무엇이 바뀌나
INT8 양자화는 크게 두 가지를 바꿉니다.
- 가중치(Weight) 양자화: FP32 가중치를 INT8로 저장
- 활성값(Activation) 양자화: 레이어 출력(activation)을 INT8로 표현
양자화는 대개 다음 형태로 표현됩니다.
x_int8 = clamp(round(x_fp32 / scale) + zero_point)x_fp32 ≈ (x_int8 - zero_point) * scale
여기서 핵심은 scale과 zero_point를 어떻게 잘 잡느냐입니다. 이 값을 잘못 잡으면 clipping/rounding 오차가 커져 정확도가 떨어집니다.
PTQ vs QAT 선택 기준
- PTQ: 학습 없이(또는 최소한의 튜닝으로) 양자화. 빠르고 싸지만 정확도 손실이 날 수 있음.
- QAT: 학습 중에 양자화 오차를 “보게” 만들어 모델이 적응하도록 함. 정확도는 유리하지만 학습 비용이 듦.
실무적으로는 다음 순서를 권합니다.
- Dynamic Quantization(가능하면)으로 빠르게 이득 확인
- Static PTQ + 캘리브레이션/옵저버 튜닝
- 그래도 정확도 안 나오면 QAT
PyTorch 양자화 스택: torch.ao.quantization 개요
PyTorch는 과거 torch.quantization에서 현재 torch.ao.quantization(AO: Architecture Optimization)로 정리되었습니다. CPU 백엔드는 보통 다음 중 하나를 사용합니다.
fbgemm: x86 서버 CPU에서 주로 사용qnnpack: ARM/모바일 계열에서 주로 사용
백엔드에 따라 지원되는 연산과 성능이 달라서, 개발 머신과 배포 머신의 CPU 아키텍처가 다르면 결과가 달라질 수 있습니다.
PTQ(Static) 실전: 정확도 유지의 핵심은 캘리브레이션
Static PTQ는 활성값까지 INT8로 만들기 때문에, 캘리브레이션 데이터로 activation 분포를 잘 관측하는 게 승부처입니다.
1) 모듈 fusion 먼저
Conv-BN-ReLU 같은 패턴은 fusion을 하면 수치적으로도 유리하고(특히 BN folding), INT8 커널 매칭도 좋아집니다.
아래는 전형적인 CNN에서의 흐름 예시입니다.
import torch
import torch.nn as nn
import torch.ao.quantization as quant
class SmallCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, 3, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = self.features(x)
x = self.pool(x).flatten(1)
return self.fc(x)
model_fp32 = SmallCNN().eval()
# fusion: Conv+BN+ReLU 패턴을 묶음
# Sequential 내부 인덱스를 정확히 지정해야 함
model_fp32.fuse_model = lambda: torch.ao.quantization.fuse_modules(
model_fp32,
[
["features.0", "features.1", "features.2"],
["features.3", "features.4", "features.5"],
],
inplace=True,
)
model_fp32.fuse_model()
fusion 인덱스가 틀리면 조용히 실패하거나(혹은 일부만 적용) 성능/정확도 모두 손해를 봅니다. 모델 구조가 복잡하면 print(model)로 모듈 경로를 먼저 확정하세요.
2) qconfig 설정: per-channel weight는 거의 필수
정확도 유지에 가장 큰 영향을 주는 설정 중 하나가 가중치 per-channel 양자화입니다. Conv/Linear weight를 채널별로 다른 scale로 양자화하면 오차가 확 줄어드는 경우가 많습니다.
import torch
import torch.ao.quantization as quant
# x86 서버라면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"
model_fp32.qconfig = quant.get_default_qconfig("fbgemm")
# 참고: 더 공격적인 설정(예: activation histogram observer)은 상황에 따라 조정
# model_fp32.qconfig = quant.QConfig(
# activation=quant.HistogramObserver.with_args(dtype=torch.quint8),
# weight=quant.default_per_channel_weight_observer
# )
get_default_qconfig는 “대체로 무난한” 선택입니다. 정확도가 모자라면 activation 옵저버를 HistogramObserver로 바꿔 clipping을 완화하거나, 캘리브레이션 샘플 수를 늘리는 쪽을 먼저 시도하는 편이 안전합니다.
3) prepare 후 캘리브레이션 수행
prepare는 옵저버를 삽입하고, 캘리브레이션 동안 통계를 모읍니다.
import torch
import torch.ao.quantization as quant
model_prepared = quant.prepare(model_fp32, inplace=False)
# 캘리브레이션: 실제 서빙 입력 분포를 반영한 데이터로 돌려야 함
# 예시는 더미 데이터
with torch.inference_mode():
for _ in range(200):
x = torch.randn(32, 3, 224, 224)
_ = model_prepared(x)
캘리브레이션 데이터는 “학습 데이터 아무거나”가 아니라, 서빙 트래픽과 유사한 분포가 중요합니다.
- 전처리(정규화/리사이즈/크롭)까지 동일해야 함
- 야간/저조도/특정 카메라 등 실제 변동성을 포함해야 함
- 배치 크기는 큰 의미가 없지만, 샘플 다양성은 중요
4) convert로 INT8 모델 생성
model_int8 = quant.convert(model_prepared, inplace=False)
# 추론
with torch.inference_mode():
y = model_int8(torch.randn(1, 3, 224, 224))
이제 model_int8는 양자화된 모듈(예: QuantizedConv2d)을 포함합니다.
PTQ에서 정확도 떨어질 때 체크리스트
- fusion 누락: Conv-BN-ReLU 미융합은 정확도/성능 모두 손해
- 캘리브레이션 부족: 샘플 수가 적거나 분포가 다르면 activation scale이 망가짐
- outlier: 극단값이 activation range를 넓혀 정밀도가 떨어짐
- 레이어 민감도: 첫 Conv, 마지막 FC, attention 계열은 특히 민감할 수 있음
- 연산 미지원 fallback: 일부 연산이 FP32로 남아 경계에서 오차가 커지기도 함
PTQ(Dynamic)로 “안전한” 첫 이득 보기
Transformer나 RNN 계열에서 Linear가 대부분일 때는 dynamic quantization이 빠르고 안정적입니다. activation을 런타임에 동적으로 스케일링하고, 주로 weight를 INT8로 바꿉니다.
import torch
import torch.nn as nn
model_fp32 = nn.Sequential(
nn.Linear(768, 768),
nn.ReLU(),
nn.Linear(768, 2),
).eval()
model_int8_dyn = torch.ao.quantization.quantize_dynamic(
model_fp32,
{nn.Linear},
dtype=torch.qint8,
)
정확도 손실이 상대적으로 적고 적용이 쉬워서, “PTQ로 될까?”를 빠르게 확인하는 용도로 좋습니다.
QAT 실전: 정확도를 지키는 정공법
PTQ로 정확도가 충분히 나오지 않는다면 QAT가 답입니다. QAT는 학습 중 forward에 fake quantization을 삽입해, 모델이 양자화 오차에 적응하도록 만듭니다.
QAT의 핵심 포인트
- 학습률을 낮추고(특히 후반) 짧게 파인튜닝하는 경우가 많음
- BatchNorm을 어떻게 다룰지 중요(동결/폴딩 타이밍)
- 학습 데이터가 서빙 분포를 대표해야 함(PTQ보다 더 중요)
QAT 코드 예시(간단 파이프라인)
import torch
import torch.nn as nn
import torch.ao.quantization as quant
torch.backends.quantized.engine = "fbgemm"
model = SmallCNN()
model.train()
# 1) fusion
model.fuse_model()
# 2) QAT qconfig
model.qconfig = quant.get_default_qat_qconfig("fbgemm")
# 3) prepare_qat
model_qat = quant.prepare_qat(model, inplace=False)
optimizer = torch.optim.SGD(model_qat.parameters(), lr=1e-4, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# 예시 학습 루프(더미 데이터)
for step in range(300):
x = torch.randn(32, 3, 224, 224)
t = torch.randint(0, 10, (32,))
optimizer.zero_grad(set_to_none=True)
y = model_qat(x)
loss = criterion(y, t)
loss.backward()
optimizer.step()
# 4) eval로 전환 후 convert
model_qat.eval()
model_int8 = quant.convert(model_qat, inplace=False)
실제 프로젝트에서는 더미 데이터 대신 실데이터로, 학습 스텝도 더 길게 잡습니다. 다만 QAT는 “처음부터 재학습”이 아니라, FP32 체크포인트에서 짧게 파인튜닝하는 방식이 비용 대비 효과가 좋습니다.
QAT에서 정확도 유지 팁
- 첫/마지막 레이어는 양자화 제외를 고려(민감도 높음)
- activation 옵저버를 histogram 기반으로 바꾸어 clipping 완화
- per-channel weight는 유지
- 학습 후반에 BN 통계를 안정화(필요시 BN freeze)
“일부만” 양자화하기: 민감 레이어 보호
정확도가 특정 레이어에서 크게 깨질 때는, 전체 INT8 고집 대신 부분 양자화가 실전적으로 더 낫습니다.
전략 예시:
- backbone은 INT8, head는 FP16/FP32
- 첫 Conv와 마지막 Linear는 FP32 유지
- attention 블록은 FP16 유지, FFN Linear만 INT8
PyTorch에서는 모듈별로 qconfig = None을 주어 제외하는 패턴이 흔합니다(모델 구조에 맞게 적용).
# 예: 마지막 fc는 양자화 제외
model_fp32.fc.qconfig = None
이런 트레이드오프는 “정확도 목표”와 “지연시간 목표”를 함께 놓고 결정해야 합니다. 운영에서 지연시간 SLO를 맞추려면, 모델만 빠르게 만드는 것 외에 타임아웃/재시도 정책도 함께 봐야 합니다. 이 관점은 Go gRPC DEADLINE_EXCEEDED 원인과 재시도·타임아웃 설계와도 연결됩니다.
성능 측정: 속도는 반드시 엔드투엔드로 재기
양자화는 커널 속도만 빨라져도, 전처리/후처리/데이터 복사에서 병목이 남으면 체감이 약합니다. 모델 단독 벤치마크와 함께, 실제 서빙 핫패스에서 측정하세요.
간단한 마이크로 벤치 예시:
import time
import torch
def bench(model, iters=200, warmup=50):
x = torch.randn(1, 3, 224, 224)
model.eval()
with torch.inference_mode():
for _ in range(warmup):
_ = model(x)
t0 = time.perf_counter()
for _ in range(iters):
_ = model(x)
t1 = time.perf_counter()
return (t1 - t0) / iters
# fp32_time = bench(model_fp32)
# int8_time = bench(model_int8)
# print(fp32_time, int8_time)
주의할 점:
- CPU 스레드 수(
torch.set_num_threads)에 따라 결과가 크게 달라짐 - 같은 머신에서 비교해야 함
- 배치 크기와 입력 크기를 실제 트래픽에 맞춰야 함
흔한 함정: 정확도는 맞는데 운영에서 깨지는 케이스
1) 캘리브레이션/학습과 서빙 전처리가 다름
가장 흔합니다. 정규화 상수, 리사이즈 방식, 컬러 채널 순서가 다르면 activation 분포가 바뀌어 INT8에서 특히 치명적입니다.
2) 모델을 train() 상태로 서빙
옵저버/BN/드롭아웃 등으로 인해 결과가 흔들립니다. 서빙 직전에는 항상 eval()과 inference_mode()를 강제하세요.
3) 연산 미지원으로 부분 FP32 fallback
겉으로는 돌아가는데 성능이 안 나오거나, 경계에서 오차가 커질 수 있습니다. 변환 후 모델 그래프를 확인하고, 어떤 모듈이 quantized로 바뀌었는지 점검하세요.
4) 지연시간은 줄었는데 재시도 폭주로 비용 증가
p99가 줄면 타임아웃을 공격적으로 줄이고 싶어지지만, 분산 환경에서는 작은 흔들림이 재시도 폭주로 이어질 수 있습니다. 모델 최적화 후에는 반드시 데드라인/리트라이 정책을 재검토하세요. 이 주제는 gRPC MSA에서 데드라인·리트라이 폭주 막는 법에서 더 깊게 다룹니다.
권장 워크플로우 요약
- 목표 정의: 정확도 하락 허용 범위, p50/p99 지연시간, CPU/메모리 예산
- Dynamic quantization으로 빠른 PoC
- Static PTQ:
- fusion
get_default_qconfig에서 시작- 캘리브레이션 데이터 품질/다양성 확보
- 부족하면 QAT:
- 짧은 파인튜닝
- 민감 레이어 부분 양자화 고려
- 엔드투엔드 벤치마크 + 운영 정책(타임아웃/리트라이) 재정렬
마무리
INT8 양자화는 “적용하면 끝”이 아니라, 분포(캘리브레이션)와 민감도(레이어별 오차) 를 다루는 작업입니다. PTQ는 캘리브레이션과 옵저버 설정이 정확도의 대부분을 좌우하고, QAT는 비용이 들지만 정확도를 지키는 가장 확실한 방법입니다.
다음 액션으로는, 현재 모델에서 PTQ를 적용한 뒤 레이어별 민감도를 측정해(예: 특정 블록만 FP32로 남겨서 비교) “어디서 깨지는지”를 먼저 찾는 것을 권합니다. 그 결과를 바탕으로 PTQ 튜닝으로 끝낼지, QAT 파인튜닝으로 갈지 판단하면 시행착오를 크게 줄일 수 있습니다.