Published on

PyTorch INT8 양자화 실전 - PTQ vs QAT

Authors

서빙 비용을 줄이기 위해 INT8 양자화를 고려할 때 가장 먼저 부딪히는 선택이 PTQ(Post-Training Quantization)QAT(Quantization Aware Training)입니다. 둘 다 목표는 동일합니다. FP32(또는 FP16/BF16) 모델을 INT8로 바꿔서 지연시간(latency), 메모리, CPU 추론 처리량을 개선하는 것.

하지만 실제 현업에서는 “어떤 모델은 PTQ로 충분히 정확도가 유지되는데, 어떤 모델은 QAT 없이는 정확도가 무너진다” 같은 일이 흔합니다. 이 글에서는 PyTorch 관점에서 PTQ와 QAT를 실전 기준으로 비교하고, 바로 가져다 쓸 수 있는 코드와 디버깅 포인트를 정리합니다.

INT8 양자화가 바꾸는 것: 스케일과 제로포인트

INT8 양자화는 실수 텐서 x를 정수 텐서 q로 근사합니다.

  • 대칭(symmetric) 예: q = round(x / scale)
  • 비대칭(asymmetric) 예: q = round(x / scale) + zero_point

PyTorch의 정적(static) 양자화에서 핵심은 아래 2가지입니다.

  1. Activation(활성화) 범위 추정: 캘리브레이션 데이터로 min/max 혹은 히스토그램 기반 범위를 잡음
  2. Weight(가중치) 양자화: 보통 per-tensor 또는 per-channel로 스케일을 잡음

여기서 Activation 범위를 잘못 잡으면 정확도 손실이 크게 납니다. PTQ와 QAT의 차이는 “이 범위를 어떻게 확보하느냐”에 가깝습니다.

PTQ vs QAT: 언제 무엇을 선택할까

PTQ(Post-Training Quantization)

학습이 끝난 모델을 대상으로 캘리브레이션만 수행하고 INT8로 변환합니다.

  • 장점
    • 학습 재실행이 거의 필요 없음
    • 구현이 상대적으로 간단
    • 모델/데이터 접근이 제한된 환경에서도 적용 가능
  • 단점
    • 분포가 민감한 모델(예: 작은 채널, 큰 아웃라이어, attention 계열)에서 정확도 급락 가능
    • 캘리브레이션 데이터 품질/대표성에 성패가 좌우됨

실무 팁: PTQ는 **“일단 빠르게 성능 이득을 확인”**하는 1차 시도로 좋습니다. 특히 CNN 기반 비전 모델이나 비교적 안정적인 MLP 계열은 PTQ로도 만족스러운 경우가 많습니다.

QAT(Quantization Aware Training)

학습 과정에서 양자화 오차를 모사(fake quantization)하고 그 오차를 포함한 상태로 파라미터를 업데이트합니다.

  • 장점
    • PTQ 대비 정확도 유지 가능성이 높음
    • activation outlier, 분포 변화에 더 강함
  • 단점
    • 학습 파이프라인이 필요(데이터, 시간, 비용)
    • 학습 안정화(러닝레이트, 스케줄, freeze 전략) 튜닝이 필요

실무 팁: QAT는 “PTQ로 정확도가 목표치에 못 미칠 때” 또는 “초기부터 INT8이 필수인 제품”에서 선택하는 편이 비용 대비 합리적입니다.

PyTorch에서 INT8 양자화의 큰 그림

PyTorch(특히 torch.ao.quantization)에서 정적 INT8 양자화 흐름은 대체로 아래입니다.

  1. 모델을 eval()로 전환
  2. qconfig 설정(백엔드 fbgemm 또는 qnnpack)
  3. prepare로 옵저버(observer) 삽입
  4. 캘리브레이션 데이터로 forward 수행
  5. convert로 실제 INT8 연산 모듈로 변환
  6. 정확도/지연시간 측정

QAT는 위 흐름에서 prepare_qat를 사용하고, 학습 단계에서 fake-quant가 들어간 상태로 fine-tune을 진행합니다.

PTQ 실전 코드: 정적(static) 양자화

아래 예시는 torchvisionresnet18을 대상으로 하는 전형적인 PTQ 파이프라인입니다. (CPU 추론 기준)

주의: 양자화는 CPU 백엔드에 강하게 의존합니다. 서버 CPU는 보통 fbgemm, 모바일은 qnnpack을 주로 씁니다.

import torch
import torch.ao.quantization as tq
from torchvision.models import resnet18

# 1) 모델 준비
model_fp32 = resnet18(weights=None)
model_fp32.eval()

# 2) 백엔드 선택
# 서버 x86: fbgemm, ARM/모바일: qnnpack
torch.backends.quantized.engine = "fbgemm"

# 3) qconfig 설정
model_fp32.qconfig = tq.get_default_qconfig(torch.backends.quantized.engine)

# 4) prepare: observer 삽입
model_prepared = tq.prepare(model_fp32, inplace=False)

# 5) calibration: 대표 데이터로 forward
# 실제로는 validation subset 등 "대표성 있는" 데이터가 중요
with torch.inference_mode():
    for _ in range(32):
        x = torch.randn(1, 3, 224, 224)
        _ = model_prepared(x)

# 6) convert: INT8 모델로 변환
model_int8 = tq.convert(model_prepared, inplace=False)

# 7) 확인
print(model_int8)

# 추론
with torch.inference_mode():
    y = model_int8(torch.randn(1, 3, 224, 224))
    print(y.shape)

PTQ에서 자주 터지는 문제 5가지

  1. 캘리브레이션 데이터가 너무 적음
    • 32배치로도 되는 모델이 있지만, 분포가 복잡한 모델은 더 필요합니다.
  2. 실제 입력 분포와 캘리브레이션 분포 불일치
    • 프로덕션 입력이 더 다양한데 validation 일부만 쓰면 activation range가 깨집니다.
  3. 연산자 미지원(op coverage)
    • 일부 커스텀 레이어 또는 특정 activation은 INT8 커널이 없을 수 있습니다.
  4. 레이어 퓨전(fuse) 누락
    • Conv + BN + ReLU 같은 패턴은 fuse가 성능과 정확도에 영향을 줍니다.
  5. per-tensor vs per-channel 설정 부적절
    • 특히 weight는 per-channel이 유리한 경우가 많습니다.

QAT 실전 코드: FakeQuant로 학습 후 변환

QAT는 학습 단계에서 fake-quant를 삽입합니다. 핵심은

  • prepare_qat 적용
  • 일정 스텝 이후 observer를 freeze 하거나 fake-quant를 고정
  • 학습을 짧게 fine-tune
import torch
import torch.nn as nn
import torch.optim as optim
import torch.ao.quantization as tq
from torchvision.models import resnet18

torch.backends.quantized.engine = "fbgemm"

model = resnet18(weights=None)
model.train()

# QAT 설정
model.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)

# prepare_qat: fake quant 모듈 삽입
model_qat = tq.prepare_qat(model, inplace=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_qat.parameters(), lr=1e-3, momentum=0.9)

# 더미 학습 루프 (실제로는 학습 데이터 사용)
for step in range(200):
    x = torch.randn(8, 3, 224, 224)
    y = torch.randint(0, 1000, (8,))

    optimizer.zero_grad()
    out = model_qat(x)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()

    # 실전 팁: 일정 step 이후 observer를 고정하여 안정화
    if step == 100:
        model_qat.apply(tq.disable_observer)
    if step == 150:
        model_qat.apply(tq.freeze_bn_stats)

# eval로 전환 후 convert
model_qat.eval()
model_int8 = tq.convert(model_qat, inplace=False)

with torch.inference_mode():
    out = model_int8(torch.randn(1, 3, 224, 224))
    print(out.shape)

QAT 튜닝 포인트

  • 학습률: 원 학습률보다 낮게 시작하는 경우가 많습니다.
  • 학습 길이: 풀 트레이닝이 아니라 짧은 fine-tune으로도 개선되는 경우가 많습니다.
  • Observer disable 시점: 너무 빨리 끄면 범위가 덜 잡히고, 너무 늦게 끄면 학습이 불안정할 수 있습니다.
  • BN 처리: BN stats freeze가 도움이 되는 경우가 많습니다.

성능 측정: “정확도”와 “지연시간”을 같이 봐야 한다

INT8의 목적은 대부분 지연시간/비용 절감입니다. 따라서 정확도만 보고 끝내면 위험합니다.

  • 정확도: top-1, F1, task-specific metric
  • 성능: p50/p95 latency, throughput(qps), CPU util, memory

간단한 CPU latency 측정 예시입니다.

import time
import torch

def bench(model, iters=200, warmup=50):
    model.eval()
    x = torch.randn(1, 3, 224, 224)

    with torch.inference_mode():
        for _ in range(warmup):
            _ = model(x)

        t0 = time.time()
        for _ in range(iters):
            _ = model(x)
        t1 = time.time()

    return (t1 - t0) * 1000 / iters

# fp32 vs int8 비교
# print("fp32 ms:", bench(model_fp32))
# print("int8 ms:", bench(model_int8))

실전에서는 스레드 수(torch.set_num_threads), 배치 크기, 입력 크기, NUMA, 컨테이너 CPU quota에 따라 결과가 크게 바뀝니다. 운영 환경과 최대한 비슷한 조건에서 재야 합니다.

PTQ가 실패하는 대표 케이스와 QAT로 넘어가는 기준

다음 중 하나라도 해당하면 QAT를 검토할 가치가 큽니다.

  • PTQ 적용 후 정확도 하락이 SLA를 초과
  • 입력 분포가 시간에 따라 변동(예: 광고/추천, 사용자 생성 콘텐츠)
  • 모델이 attention 기반이며 activation outlier가 큼
  • 캘리브레이션 데이터를 충분히 확보하기 어렵거나 대표성이 낮음

반대로 아래라면 PTQ로 끝낼 가능성이 높습니다.

  • CNN 계열, activation 분포가 비교적 안정적
  • 캘리브레이션 데이터를 넉넉히 확보 가능
  • 약간의 정확도 손실이 비용 절감 대비 허용 가능

배포 체크리스트: “변환 성공”과 “운영 안전”은 다르다

양자화 모델이 로컬에서 돌아가는 것과 운영에서 문제 없이 도는 것은 별개입니다.

  • 모델 저장/로딩
    • state_dict만 저장할지, torch.jit.trace/script로 패키징할지 결정
  • CPU 백엔드 일치
    • 개발 머신과 운영 머신이 같은 엔진(fbgemm)을 쓰는지 확인
  • 폴백(fallback) 전략
    • 특정 연산이 INT8 미지원이면 FP32로 폴백되며 성능 이점이 사라질 수 있음
  • 입력 전처리 일관성
    • 캘리브레이션과 운영 입력 스케일/정규화가 다르면 activation range가 틀어질 수 있음

운영 장애 관점에서 보면 “한 번에 크게 바꾸지 말고 점진적으로 롤아웃하고, 재시도/백오프 같은 보호 장치”가 중요합니다. API 트래픽 제어와 재시도 설계가 필요하다면 OpenAI 429/Rate Limit 대응 - 재시도·백오프·큐잉 글의 패턴도 서버 운영에 그대로 응용할 수 있습니다.

결론: 실전 선택 가이드

  • 빠른 승부: PTQ부터
    • 캘리브레이션 데이터만 잘 준비해도 꽤 많은 모델이 통과합니다.
  • 정확도가 핵심: QAT로 보강
    • PTQ에서 정확도가 무너지면 QAT가 가장 확실한 해법인 경우가 많습니다.
  • 성공의 80퍼센트는 데이터
    • PTQ는 캘리브레이션 대표성이, QAT는 fine-tune 데이터 품질이 성패를 좌우합니다.

다음 단계로는 모델 구조(Conv/Transformer), 목표 하드웨어(CPU/모바일), 그리고 PyTorch 버전에 따라 FX Graph Mode Quantization이나 torch.compile과의 조합까지 고려하게 됩니다. 하지만 대부분의 제품에서는 이 글의 PTQ/QAT 파이프라인만 제대로 정리해도 INT8 전환의 첫 프로덕션 릴리스까지 충분히 도달할 수 있습니다.