- Published on
PyTorch 2.0 PTQ로 INT8 양자화 정확도 지키기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 모델을 더 빠르고 저렴하게 돌리려면 INT8 양자화가 매력적입니다. 다만 “속도는 빨라졌는데 정확도가 무너졌다”는 경험도 흔합니다. 특히 PTQ(Post-Training Quantization)는 학습을 다시 하지 않기 때문에, 캘리브레이션과 관측자(observer) 선택이 곧 정확도와 직결됩니다.
이 글에서는 PyTorch 2.0 기준으로 Eager 모드의 PTQ 흐름을 중심으로, INT8 양자화에서 정확도를 지키기 위한 실전 포인트를 정리합니다. (PyTorch의 FX Graph Mode나 torch.compile 기반 흐름도 있지만, PTQ 정확도 이슈의 핵심은 동일합니다.)
또한 로컬 LLM에서 메모리 압박을 줄이는 방향이 궁금하다면 Transformers 로컬 LLM OOM 해결 - 4bit+KV 캐시도 함께 보면, “정확도 vs 비용” 트레이드오프를 더 입체적으로 잡을 수 있습니다.
PTQ에서 정확도가 깨지는 대표 원인
INT8 PTQ는 대략 다음 과정을 거칩니다.
- FP32 모델 준비
- 관측자 삽입(activation 통계 수집)
- 캘리브레이션 데이터로 forward 수행
- 통계를 바탕으로 scale, zero-point 결정
- INT8로 변환(convert)
정확도 하락은 주로 아래에서 발생합니다.
- 캘리브레이션 데이터가 실제 트래픽 분포를 못 따라감
- activation outlier가 많아 MinMax 기반 scale이 깨짐
- 레이어별 민감도 차이: 어떤 레이어는 INT8로 내리면 손실이 급격함
- 연산 패턴 미스매치: Conv/Linear는 INT8 최적화가 잘 되지만, 일부 연산은 dequant-quant가 반복되며 손실과 지연이 증가
- per-tensor vs per-channel 선택 실패: 특히 weight는 per-channel이 정확도에 유리한 경우가 많음
사전 준비: 백엔드와 qconfig 이해하기
PyTorch 양자화는 CPU 백엔드(예: FBGEMM, QNNPACK)에 따라 지원 연산과 성능이 달라집니다.
- 서버 x86 CPU: 보통
fbgemm - 모바일 ARM: 보통
qnnpack
다음은 기본 설정 예시입니다.
import torch
import torch.ao.quantization as tq
torch.backends.quantized.engine = "fbgemm" # x86 서버 기준
그리고 PTQ에서 가장 중요한 설정이 qconfig입니다. qconfig는 크게 두 가지를 정합니다.
- activation 관측자(예: MinMaxObserver, HistogramObserver)
- weight 관측자(보통 per-channel 지원 여부가 중요)
실전 1: 기본 PTQ 파이프라인(Conv/Linear 중심)
아래는 Eager 모드에서 흔히 쓰는 PTQ 파이프라인입니다. 예시는 단순화를 위해 Conv2d나 Linear 위주 모델을 가정합니다.
import copy
import torch
import torch.nn as nn
import torch.ao.quantization as tq
class SmallNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1)
self.relu = nn.ReLU()
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
fp32_model = SmallNet().eval()
# 1) qconfig 설정
qconfig = tq.get_default_qconfig("fbgemm")
# 2) 복사본에 qconfig 적용
model_to_quantize = copy.deepcopy(fp32_model)
model_to_quantize.qconfig = qconfig
# 3) 관측자 삽입
prepared = tq.prepare(model_to_quantize, inplace=False)
# 4) 캘리브레이션
@torch.no_grad()
def calibrate(model, data_loader, num_batches=32):
model.eval()
for i, (x, _) in enumerate(data_loader):
model(x)
if i + 1 >= num_batches:
break
# calibrate(prepared, calib_loader)
# 5) INT8 변환
quantized = tq.convert(prepared, inplace=False)
이 단계까지는 “동작하는 INT8 모델”을 만드는 과정입니다. 하지만 정확도를 지키려면 여기서부터가 진짜 시작입니다.
실전 2: 캘리브레이션 데이터가 정확도의 70%를 결정한다
PTQ에서 캘리브레이션은 학습이 아니라 “통계 수집”입니다. 그래서 다음 원칙이 중요합니다.
1) 실제 서빙 입력 분포를 최대한 반영
- 이미지라면 전처리(리사이즈, 정규화)까지 포함해 동일하게
- 텍스트라면 토크나이저, max length, padding 전략을 동일하게
2) 배치 수보다 “분포 커버리지”
- 1만 장을 무작정 넣기보다, 클래스/도메인/난이도/밝기/길이 등 분포를 고르게
- outlier가 잦은 서비스라면 outlier도 포함하되, 비율을 현실적으로
3) 캘리브레이션을 너무 적게 하면
- activation 범위 추정이 불안정해 scale이 흔들리고, 특정 입력에서 saturation이 발생
4) 너무 많이 하면
- 시간만 늘고 개선 폭이 줄어드는 구간이 빨리 옵니다
실무에서는 보통 “수십~수백 배치”로 먼저 스윕하고, 정확도 하락이 큰 경우에만 캘리브레이션 셋을 재구성하는 편이 효율적입니다.
실전 3: MinMax 대신 Histogram(또는 MovingAverage)로 outlier 완화
activation outlier가 있는 모델은 MinMaxObserver가 극단값에 끌려가서 대부분의 값이 좁은 구간에 뭉개지는 문제가 생깁니다. 이때 HistogramObserver가 도움이 됩니다.
import torch.ao.quantization as tq
act_observer = tq.HistogramObserver.with_args(
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=True,
)
weight_observer = tq.PerChannelMinMaxObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
)
qconfig = tq.QConfig(activation=act_observer, weight=weight_observer)
포인트는 다음입니다.
- weight는 가능하면 per-channel symmetric을 우선 고려
- activation은 모델 특성에 따라 histogram 계열을 시험
reduce_range=True는 일부 백엔드에서 정확도/호환성에 영향을 줄 수 있어 A/B로 확인
실전 4: 레이어별로 “양자화 제외” 또는 “관측자 교체” 하기
모든 레이어를 INT8로 내리는 것이 최선은 아닙니다. 특히 다음은 민감한 경우가 많습니다.
- 출력단
Linear(클래스 로짓이 작은 차이에 민감) - attention 계열(특히 softmax 주변)
- LayerNorm / 일부 normalization
Eager 모드에서는 모듈 단위로 qconfig = None을 줘서 제외할 수 있습니다.
import copy
import torch.ao.quantization as tq
model = copy.deepcopy(fp32_model).eval()
model.qconfig = tq.get_default_qconfig("fbgemm")
# 예: 마지막 fc는 FP32로 유지
model.fc.qconfig = None
prepared = tq.prepare(model, inplace=False)
# calibrate(prepared, calib_loader)
quantized = tq.convert(prepared, inplace=False)
정확도 하락이 큰데 속도 이득이 충분하다면, “대부분 INT8 + 민감 레이어 FP32” 하이브리드가 실전에서 자주 이깁니다.
실전 5: Fusion(Conv+ReLU 등)으로 정확도와 성능을 동시에
연산을 fuse하면 quantize/dequantize 경계가 줄어들어 정확도와 성능이 동시에 좋아질 수 있습니다. 대표적으로 Conv + ReLU는 fuse 후보입니다.
import torch.ao.quantization as tq
class FuseNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.relu = nn.ReLU()
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = x.mean(dim=(2, 3))
return self.fc(x)
model = FuseNet().eval()
# fuse_modules는 모듈 이름 리스트로 지정
fused = tq.fuse_modules(model, [["conv", "relu"]], inplace=False)
fused.qconfig = tq.get_default_qconfig("fbgemm")
prepared = tq.prepare(fused, inplace=False)
# calibrate(prepared, calib_loader)
quantized = tq.convert(prepared, inplace=False)
주의할 점은, fuse 가능한 패턴이 모델 구조와 PyTorch 버전에 따라 다르므로, fuse 후 수치가 달라지는지(특히 BN 포함 시) 검증이 필요합니다.
실전 6: 정확도 검증은 “전체 스코어” 말고 구간별로 쪼개기
INT8 PTQ의 정확도 하락은 특정 입력 구간에서만 터지는 경우가 많습니다.
- 밝은 이미지/어두운 이미지
- 긴 문장/짧은 문장
- 특정 클래스 또는 특정 도메인
따라서 단일 Top-1이나 F1만 보지 말고, 최소한 아래를 분리해서 보세요.
- 클래스별 confusion
- 입력 길이/밝기/스케일 버킷별 정확도
- outlier 케이스 리그레션 셋
이 방식은 성능 튜닝에서도 동일하게 유효합니다. 예를 들어 네트워크가 간헐적으로 흔들릴 때 전체 평균만 보면 원인이 가려집니다. 그런 문제는 EKS Pod egress 간헐 끊김 - SNAT·NAT GW 추적법처럼 “구간을 쪼개 추적”해야 잡히는 것과 결이 같습니다.
실전 7: 측정할 때는 반드시 warm-up과 스레드 고정
INT8이 빨라졌는지 판단하려면 측정이 흔들리면 안 됩니다.
- warm-up 실행
- CPU 스레드 수 고정
- 같은 배치 크기, 같은 입력 shape
import time
import torch
def benchmark(model, input_tensor, iters=200, warmup=20):
model.eval()
with torch.no_grad():
for _ in range(warmup):
_ = model(input_tensor)
t0 = time.time()
for _ in range(iters):
_ = model(input_tensor)
t1 = time.time()
return (t1 - t0) / iters
# 예: 단일 스레드로 고정(환경에 따라 조정)
# torch.set_num_threads(1)
x = torch.randn(1, 3, 224, 224)
# fp32_time = benchmark(fp32_model, x)
# int8_time = benchmark(quantized, x)
정확도도 마찬가지로, 평가 시 전처리/후처리/threshold를 동일하게 맞추지 않으면 “양자화 탓”으로 오판하기 쉽습니다.
트러블슈팅 체크리스트(정확도 하락이 클 때)
- 캘리브레이션 셋을 교체: 실제 트래픽에서 샘플링했는가
- activation observer를 histogram 계열로 변경
- weight를 per-channel로 강제: 가능하면 symmetric
- 민감 레이어 제외: 마지막 FC, softmax 주변, normalization
- fuse 적용: Conv+ReLU, Conv+BN+ReLU(가능한 패턴 확인)
- 입력 shape 고정: 동적 shape가 많으면 통계가 섞여 악화 가능
- 스코어를 구간별로 분석: 특정 케이스만 무너지는지 확인
마무리: “INT8로 바꾼다”가 아니라 “분포를 이식한다”
PTQ는 학습을 다시 하지 않는 대신, FP32 모델이 보던 입력 분포와 activation 분포를 INT8의 제한된 표현 범위에 최대한 손실 없이 옮기는 작업입니다. 그래서 캘리브레이션 데이터와 관측자 설정이 곧 모델의 “새로운 수치 세계관”을 결정합니다.
정리하면, PyTorch 2.0에서 INT8 PTQ 정확도를 지키는 가장 현실적인 우선순위는 다음입니다.
- 캘리브레이션 데이터 품질을 먼저 올리고
- activation observer를 outlier 친화적으로 바꾸고
- weight는 per-channel을 우선 적용하며
- 민감 레이어는 과감히 FP32로 남기는 하이브리드를 고려
이 4가지만 제대로 해도, “속도는 얻고 정확도는 지키는” PTQ 성공 확률이 크게 올라갑니다.