- Published on
PyTorch INT8 양자화 실전 - PTQ vs QAT
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 비용을 줄이기 위해 INT8 양자화를 고려할 때 가장 먼저 부딪히는 선택이 PTQ(Post-Training Quantization)와 QAT(Quantization Aware Training)입니다. 둘 다 목표는 동일합니다. FP32(또는 FP16/BF16) 모델을 INT8로 바꿔서 지연시간(latency), 메모리, CPU 추론 처리량을 개선하는 것.
하지만 실제 현업에서는 “어떤 모델은 PTQ로 충분히 정확도가 유지되는데, 어떤 모델은 QAT 없이는 정확도가 무너진다” 같은 일이 흔합니다. 이 글에서는 PyTorch 관점에서 PTQ와 QAT를 실전 기준으로 비교하고, 바로 가져다 쓸 수 있는 코드와 디버깅 포인트를 정리합니다.
INT8 양자화가 바꾸는 것: 스케일과 제로포인트
INT8 양자화는 실수 텐서 x를 정수 텐서 q로 근사합니다.
- 대칭(symmetric) 예:
q = round(x / scale) - 비대칭(asymmetric) 예:
q = round(x / scale) + zero_point
PyTorch의 정적(static) 양자화에서 핵심은 아래 2가지입니다.
- Activation(활성화) 범위 추정: 캘리브레이션 데이터로
min/max혹은 히스토그램 기반 범위를 잡음 - Weight(가중치) 양자화: 보통 per-tensor 또는 per-channel로 스케일을 잡음
여기서 Activation 범위를 잘못 잡으면 정확도 손실이 크게 납니다. PTQ와 QAT의 차이는 “이 범위를 어떻게 확보하느냐”에 가깝습니다.
PTQ vs QAT: 언제 무엇을 선택할까
PTQ(Post-Training Quantization)
학습이 끝난 모델을 대상으로 캘리브레이션만 수행하고 INT8로 변환합니다.
- 장점
- 학습 재실행이 거의 필요 없음
- 구현이 상대적으로 간단
- 모델/데이터 접근이 제한된 환경에서도 적용 가능
- 단점
- 분포가 민감한 모델(예: 작은 채널, 큰 아웃라이어, attention 계열)에서 정확도 급락 가능
- 캘리브레이션 데이터 품질/대표성에 성패가 좌우됨
실무 팁: PTQ는 **“일단 빠르게 성능 이득을 확인”**하는 1차 시도로 좋습니다. 특히 CNN 기반 비전 모델이나 비교적 안정적인 MLP 계열은 PTQ로도 만족스러운 경우가 많습니다.
QAT(Quantization Aware Training)
학습 과정에서 양자화 오차를 모사(fake quantization)하고 그 오차를 포함한 상태로 파라미터를 업데이트합니다.
- 장점
- PTQ 대비 정확도 유지 가능성이 높음
- activation outlier, 분포 변화에 더 강함
- 단점
- 학습 파이프라인이 필요(데이터, 시간, 비용)
- 학습 안정화(러닝레이트, 스케줄, freeze 전략) 튜닝이 필요
실무 팁: QAT는 “PTQ로 정확도가 목표치에 못 미칠 때” 또는 “초기부터 INT8이 필수인 제품”에서 선택하는 편이 비용 대비 합리적입니다.
PyTorch에서 INT8 양자화의 큰 그림
PyTorch(특히 torch.ao.quantization)에서 정적 INT8 양자화 흐름은 대체로 아래입니다.
- 모델을
eval()로 전환 qconfig설정(백엔드fbgemm또는qnnpack)prepare로 옵저버(observer) 삽입- 캘리브레이션 데이터로 forward 수행
convert로 실제 INT8 연산 모듈로 변환- 정확도/지연시간 측정
QAT는 위 흐름에서 prepare_qat를 사용하고, 학습 단계에서 fake-quant가 들어간 상태로 fine-tune을 진행합니다.
PTQ 실전 코드: 정적(static) 양자화
아래 예시는 torchvision의 resnet18을 대상으로 하는 전형적인 PTQ 파이프라인입니다. (CPU 추론 기준)
주의: 양자화는 CPU 백엔드에 강하게 의존합니다. 서버 CPU는 보통 fbgemm, 모바일은 qnnpack을 주로 씁니다.
import torch
import torch.ao.quantization as tq
from torchvision.models import resnet18
# 1) 모델 준비
model_fp32 = resnet18(weights=None)
model_fp32.eval()
# 2) 백엔드 선택
# 서버 x86: fbgemm, ARM/모바일: qnnpack
torch.backends.quantized.engine = "fbgemm"
# 3) qconfig 설정
model_fp32.qconfig = tq.get_default_qconfig(torch.backends.quantized.engine)
# 4) prepare: observer 삽입
model_prepared = tq.prepare(model_fp32, inplace=False)
# 5) calibration: 대표 데이터로 forward
# 실제로는 validation subset 등 "대표성 있는" 데이터가 중요
with torch.inference_mode():
for _ in range(32):
x = torch.randn(1, 3, 224, 224)
_ = model_prepared(x)
# 6) convert: INT8 모델로 변환
model_int8 = tq.convert(model_prepared, inplace=False)
# 7) 확인
print(model_int8)
# 추론
with torch.inference_mode():
y = model_int8(torch.randn(1, 3, 224, 224))
print(y.shape)
PTQ에서 자주 터지는 문제 5가지
- 캘리브레이션 데이터가 너무 적음
- 32배치로도 되는 모델이 있지만, 분포가 복잡한 모델은 더 필요합니다.
- 실제 입력 분포와 캘리브레이션 분포 불일치
- 프로덕션 입력이 더 다양한데 validation 일부만 쓰면 activation range가 깨집니다.
- 연산자 미지원(op coverage)
- 일부 커스텀 레이어 또는 특정 activation은 INT8 커널이 없을 수 있습니다.
- 레이어 퓨전(fuse) 누락
Conv + BN + ReLU같은 패턴은 fuse가 성능과 정확도에 영향을 줍니다.
- per-tensor vs per-channel 설정 부적절
- 특히 weight는 per-channel이 유리한 경우가 많습니다.
QAT 실전 코드: FakeQuant로 학습 후 변환
QAT는 학습 단계에서 fake-quant를 삽입합니다. 핵심은
prepare_qat적용- 일정 스텝 이후 observer를 freeze 하거나 fake-quant를 고정
- 학습을 짧게 fine-tune
import torch
import torch.nn as nn
import torch.optim as optim
import torch.ao.quantization as tq
from torchvision.models import resnet18
torch.backends.quantized.engine = "fbgemm"
model = resnet18(weights=None)
model.train()
# QAT 설정
model.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)
# prepare_qat: fake quant 모듈 삽입
model_qat = tq.prepare_qat(model, inplace=False)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_qat.parameters(), lr=1e-3, momentum=0.9)
# 더미 학습 루프 (실제로는 학습 데이터 사용)
for step in range(200):
x = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 1000, (8,))
optimizer.zero_grad()
out = model_qat(x)
loss = criterion(out, y)
loss.backward()
optimizer.step()
# 실전 팁: 일정 step 이후 observer를 고정하여 안정화
if step == 100:
model_qat.apply(tq.disable_observer)
if step == 150:
model_qat.apply(tq.freeze_bn_stats)
# eval로 전환 후 convert
model_qat.eval()
model_int8 = tq.convert(model_qat, inplace=False)
with torch.inference_mode():
out = model_int8(torch.randn(1, 3, 224, 224))
print(out.shape)
QAT 튜닝 포인트
- 학습률: 원 학습률보다 낮게 시작하는 경우가 많습니다.
- 학습 길이: 풀 트레이닝이 아니라 짧은 fine-tune으로도 개선되는 경우가 많습니다.
- Observer disable 시점: 너무 빨리 끄면 범위가 덜 잡히고, 너무 늦게 끄면 학습이 불안정할 수 있습니다.
- BN 처리: BN stats freeze가 도움이 되는 경우가 많습니다.
성능 측정: “정확도”와 “지연시간”을 같이 봐야 한다
INT8의 목적은 대부분 지연시간/비용 절감입니다. 따라서 정확도만 보고 끝내면 위험합니다.
- 정확도:
top-1,F1, task-specific metric - 성능:
p50/p95 latency,throughput(qps),CPU util,memory
간단한 CPU latency 측정 예시입니다.
import time
import torch
def bench(model, iters=200, warmup=50):
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.inference_mode():
for _ in range(warmup):
_ = model(x)
t0 = time.time()
for _ in range(iters):
_ = model(x)
t1 = time.time()
return (t1 - t0) * 1000 / iters
# fp32 vs int8 비교
# print("fp32 ms:", bench(model_fp32))
# print("int8 ms:", bench(model_int8))
실전에서는 스레드 수(torch.set_num_threads), 배치 크기, 입력 크기, NUMA, 컨테이너 CPU quota에 따라 결과가 크게 바뀝니다. 운영 환경과 최대한 비슷한 조건에서 재야 합니다.
PTQ가 실패하는 대표 케이스와 QAT로 넘어가는 기준
다음 중 하나라도 해당하면 QAT를 검토할 가치가 큽니다.
- PTQ 적용 후 정확도 하락이 SLA를 초과
- 입력 분포가 시간에 따라 변동(예: 광고/추천, 사용자 생성 콘텐츠)
- 모델이 attention 기반이며 activation outlier가 큼
- 캘리브레이션 데이터를 충분히 확보하기 어렵거나 대표성이 낮음
반대로 아래라면 PTQ로 끝낼 가능성이 높습니다.
- CNN 계열, activation 분포가 비교적 안정적
- 캘리브레이션 데이터를 넉넉히 확보 가능
- 약간의 정확도 손실이 비용 절감 대비 허용 가능
배포 체크리스트: “변환 성공”과 “운영 안전”은 다르다
양자화 모델이 로컬에서 돌아가는 것과 운영에서 문제 없이 도는 것은 별개입니다.
- 모델 저장/로딩
state_dict만 저장할지,torch.jit.trace/script로 패키징할지 결정
- CPU 백엔드 일치
- 개발 머신과 운영 머신이 같은 엔진(
fbgemm)을 쓰는지 확인
- 개발 머신과 운영 머신이 같은 엔진(
- 폴백(fallback) 전략
- 특정 연산이 INT8 미지원이면 FP32로 폴백되며 성능 이점이 사라질 수 있음
- 입력 전처리 일관성
- 캘리브레이션과 운영 입력 스케일/정규화가 다르면 activation range가 틀어질 수 있음
운영 장애 관점에서 보면 “한 번에 크게 바꾸지 말고 점진적으로 롤아웃하고, 재시도/백오프 같은 보호 장치”가 중요합니다. API 트래픽 제어와 재시도 설계가 필요하다면 OpenAI 429/Rate Limit 대응 - 재시도·백오프·큐잉 글의 패턴도 서버 운영에 그대로 응용할 수 있습니다.
결론: 실전 선택 가이드
- 빠른 승부: PTQ부터
- 캘리브레이션 데이터만 잘 준비해도 꽤 많은 모델이 통과합니다.
- 정확도가 핵심: QAT로 보강
- PTQ에서 정확도가 무너지면 QAT가 가장 확실한 해법인 경우가 많습니다.
- 성공의 80퍼센트는 데이터
- PTQ는 캘리브레이션 대표성이, QAT는 fine-tune 데이터 품질이 성패를 좌우합니다.
다음 단계로는 모델 구조(Conv/Transformer), 목표 하드웨어(CPU/모바일), 그리고 PyTorch 버전에 따라 FX Graph Mode Quantization이나 torch.compile과의 조합까지 고려하게 됩니다. 하지만 대부분의 제품에서는 이 글의 PTQ/QAT 파이프라인만 제대로 정리해도 INT8 전환의 첫 프로덕션 릴리스까지 충분히 도달할 수 있습니다.