Published on

PyTorch PTQ/QAT로 INT8 양자화 - 정확도 유지

Authors

서빙 환경에서 모델을 더 빠르고 가볍게 만들고 싶을 때, 가장 먼저 검토하는 카드가 INT8 양자화입니다. FP32 대비 메모리 사용량을 줄이고(가중치 기준 약 4분의 1), CPU 추론에서 특히 유의미한 속도 향상을 기대할 수 있습니다. 문제는 정확도입니다. 같은 INT8이라도 PTQ(Post-Training Quantization) 로 끝낼지, QAT(Quantization-Aware Training) 까지 갈지에 따라 정확도 손실과 개발 비용이 크게 달라집니다.

이 글은 PyTorch(특히 torch.ao.quantization)에서 PTQ/QAT를 적용해 정확도 하락을 최소화하는 방법을, 실수하기 쉬운 포인트(캘리브레이션, 관측자, fusion, per-channel, backend 설정) 중심으로 정리합니다.

또한 운영 관점에서 “성능 최적화는 결국 병목을 줄이는 일”이라는 점에서, 재시도 폭주/타임아웃 설계처럼 시스템 레벨에서의 성능 안정화도 함께 고려해야 합니다. 관련해서는 gRPC MSA에서 데드라인·리트라이 폭주 막는 법도 함께 참고하면 좋습니다.

INT8 양자화 기본: 무엇이 바뀌나

INT8 양자화는 크게 두 가지를 바꿉니다.

  1. 가중치(Weight) 양자화: FP32 가중치를 INT8로 저장
  2. 활성값(Activation) 양자화: 레이어 출력(activation)을 INT8로 표현

양자화는 대개 다음 형태로 표현됩니다.

  • x_int8 = clamp(round(x_fp32 / scale) + zero_point)
  • x_fp32 ≈ (x_int8 - zero_point) * scale

여기서 핵심은 scalezero_point를 어떻게 잘 잡느냐입니다. 이 값을 잘못 잡으면 clipping/rounding 오차가 커져 정확도가 떨어집니다.

PTQ vs QAT 선택 기준

  • PTQ: 학습 없이(또는 최소한의 튜닝으로) 양자화. 빠르고 싸지만 정확도 손실이 날 수 있음.
  • QAT: 학습 중에 양자화 오차를 “보게” 만들어 모델이 적응하도록 함. 정확도는 유리하지만 학습 비용이 듦.

실무적으로는 다음 순서를 권합니다.

  1. Dynamic Quantization(가능하면)으로 빠르게 이득 확인
  2. Static PTQ + 캘리브레이션/옵저버 튜닝
  3. 그래도 정확도 안 나오면 QAT

PyTorch 양자화 스택: torch.ao.quantization 개요

PyTorch는 과거 torch.quantization에서 현재 torch.ao.quantization(AO: Architecture Optimization)로 정리되었습니다. CPU 백엔드는 보통 다음 중 하나를 사용합니다.

  • fbgemm: x86 서버 CPU에서 주로 사용
  • qnnpack: ARM/모바일 계열에서 주로 사용

백엔드에 따라 지원되는 연산과 성능이 달라서, 개발 머신과 배포 머신의 CPU 아키텍처가 다르면 결과가 달라질 수 있습니다.

PTQ(Static) 실전: 정확도 유지의 핵심은 캘리브레이션

Static PTQ는 활성값까지 INT8로 만들기 때문에, 캘리브레이션 데이터로 activation 분포를 잘 관측하는 게 승부처입니다.

1) 모듈 fusion 먼저

Conv-BN-ReLU 같은 패턴은 fusion을 하면 수치적으로도 유리하고(특히 BN folding), INT8 커널 매칭도 좋아집니다.

아래는 전형적인 CNN에서의 흐름 예시입니다.

import torch
import torch.nn as nn
import torch.ao.quantization as quant

class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)

model_fp32 = SmallCNN().eval()

# fusion: Conv+BN+ReLU 패턴을 묶음
# Sequential 내부 인덱스를 정확히 지정해야 함
model_fp32.fuse_model = lambda: torch.ao.quantization.fuse_modules(
    model_fp32,
    [
        ["features.0", "features.1", "features.2"],
        ["features.3", "features.4", "features.5"],
    ],
    inplace=True,
)

model_fp32.fuse_model()

fusion 인덱스가 틀리면 조용히 실패하거나(혹은 일부만 적용) 성능/정확도 모두 손해를 봅니다. 모델 구조가 복잡하면 print(model)로 모듈 경로를 먼저 확정하세요.

2) qconfig 설정: per-channel weight는 거의 필수

정확도 유지에 가장 큰 영향을 주는 설정 중 하나가 가중치 per-channel 양자화입니다. Conv/Linear weight를 채널별로 다른 scale로 양자화하면 오차가 확 줄어드는 경우가 많습니다.

import torch
import torch.ao.quantization as quant

# x86 서버라면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"

model_fp32.qconfig = quant.get_default_qconfig("fbgemm")

# 참고: 더 공격적인 설정(예: activation histogram observer)은 상황에 따라 조정
# model_fp32.qconfig = quant.QConfig(
#     activation=quant.HistogramObserver.with_args(dtype=torch.quint8),
#     weight=quant.default_per_channel_weight_observer
# )

get_default_qconfig는 “대체로 무난한” 선택입니다. 정확도가 모자라면 activation 옵저버를 HistogramObserver로 바꿔 clipping을 완화하거나, 캘리브레이션 샘플 수를 늘리는 쪽을 먼저 시도하는 편이 안전합니다.

3) prepare 후 캘리브레이션 수행

prepare는 옵저버를 삽입하고, 캘리브레이션 동안 통계를 모읍니다.

import torch
import torch.ao.quantization as quant

model_prepared = quant.prepare(model_fp32, inplace=False)

# 캘리브레이션: 실제 서빙 입력 분포를 반영한 데이터로 돌려야 함
# 예시는 더미 데이터
with torch.inference_mode():
    for _ in range(200):
        x = torch.randn(32, 3, 224, 224)
        _ = model_prepared(x)

캘리브레이션 데이터는 “학습 데이터 아무거나”가 아니라, 서빙 트래픽과 유사한 분포가 중요합니다.

  • 전처리(정규화/리사이즈/크롭)까지 동일해야 함
  • 야간/저조도/특정 카메라 등 실제 변동성을 포함해야 함
  • 배치 크기는 큰 의미가 없지만, 샘플 다양성은 중요

4) convert로 INT8 모델 생성

model_int8 = quant.convert(model_prepared, inplace=False)

# 추론
with torch.inference_mode():
    y = model_int8(torch.randn(1, 3, 224, 224))

이제 model_int8는 양자화된 모듈(예: QuantizedConv2d)을 포함합니다.

PTQ에서 정확도 떨어질 때 체크리스트

  1. fusion 누락: Conv-BN-ReLU 미융합은 정확도/성능 모두 손해
  2. 캘리브레이션 부족: 샘플 수가 적거나 분포가 다르면 activation scale이 망가짐
  3. outlier: 극단값이 activation range를 넓혀 정밀도가 떨어짐
  4. 레이어 민감도: 첫 Conv, 마지막 FC, attention 계열은 특히 민감할 수 있음
  5. 연산 미지원 fallback: 일부 연산이 FP32로 남아 경계에서 오차가 커지기도 함

PTQ(Dynamic)로 “안전한” 첫 이득 보기

Transformer나 RNN 계열에서 Linear가 대부분일 때는 dynamic quantization이 빠르고 안정적입니다. activation을 런타임에 동적으로 스케일링하고, 주로 weight를 INT8로 바꿉니다.

import torch
import torch.nn as nn

model_fp32 = nn.Sequential(
    nn.Linear(768, 768),
    nn.ReLU(),
    nn.Linear(768, 2),
).eval()

model_int8_dyn = torch.ao.quantization.quantize_dynamic(
    model_fp32,
    {nn.Linear},
    dtype=torch.qint8,
)

정확도 손실이 상대적으로 적고 적용이 쉬워서, “PTQ로 될까?”를 빠르게 확인하는 용도로 좋습니다.

QAT 실전: 정확도를 지키는 정공법

PTQ로 정확도가 충분히 나오지 않는다면 QAT가 답입니다. QAT는 학습 중 forward에 fake quantization을 삽입해, 모델이 양자화 오차에 적응하도록 만듭니다.

QAT의 핵심 포인트

  • 학습률을 낮추고(특히 후반) 짧게 파인튜닝하는 경우가 많음
  • BatchNorm을 어떻게 다룰지 중요(동결/폴딩 타이밍)
  • 학습 데이터가 서빙 분포를 대표해야 함(PTQ보다 더 중요)

QAT 코드 예시(간단 파이프라인)

import torch
import torch.nn as nn
import torch.ao.quantization as quant

torch.backends.quantized.engine = "fbgemm"

model = SmallCNN()
model.train()

# 1) fusion
model.fuse_model()

# 2) QAT qconfig
model.qconfig = quant.get_default_qat_qconfig("fbgemm")

# 3) prepare_qat
model_qat = quant.prepare_qat(model, inplace=False)

optimizer = torch.optim.SGD(model_qat.parameters(), lr=1e-4, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 예시 학습 루프(더미 데이터)
for step in range(300):
    x = torch.randn(32, 3, 224, 224)
    t = torch.randint(0, 10, (32,))

    optimizer.zero_grad(set_to_none=True)
    y = model_qat(x)
    loss = criterion(y, t)
    loss.backward()
    optimizer.step()

# 4) eval로 전환 후 convert
model_qat.eval()
model_int8 = quant.convert(model_qat, inplace=False)

실제 프로젝트에서는 더미 데이터 대신 실데이터로, 학습 스텝도 더 길게 잡습니다. 다만 QAT는 “처음부터 재학습”이 아니라, FP32 체크포인트에서 짧게 파인튜닝하는 방식이 비용 대비 효과가 좋습니다.

QAT에서 정확도 유지 팁

  • 첫/마지막 레이어는 양자화 제외를 고려(민감도 높음)
  • activation 옵저버를 histogram 기반으로 바꾸어 clipping 완화
  • per-channel weight는 유지
  • 학습 후반에 BN 통계를 안정화(필요시 BN freeze)

“일부만” 양자화하기: 민감 레이어 보호

정확도가 특정 레이어에서 크게 깨질 때는, 전체 INT8 고집 대신 부분 양자화가 실전적으로 더 낫습니다.

전략 예시:

  • backbone은 INT8, head는 FP16/FP32
  • 첫 Conv와 마지막 Linear는 FP32 유지
  • attention 블록은 FP16 유지, FFN Linear만 INT8

PyTorch에서는 모듈별로 qconfig = None을 주어 제외하는 패턴이 흔합니다(모델 구조에 맞게 적용).

# 예: 마지막 fc는 양자화 제외
model_fp32.fc.qconfig = None

이런 트레이드오프는 “정확도 목표”와 “지연시간 목표”를 함께 놓고 결정해야 합니다. 운영에서 지연시간 SLO를 맞추려면, 모델만 빠르게 만드는 것 외에 타임아웃/재시도 정책도 함께 봐야 합니다. 이 관점은 Go gRPC DEADLINE_EXCEEDED 원인과 재시도·타임아웃 설계와도 연결됩니다.

성능 측정: 속도는 반드시 엔드투엔드로 재기

양자화는 커널 속도만 빨라져도, 전처리/후처리/데이터 복사에서 병목이 남으면 체감이 약합니다. 모델 단독 벤치마크와 함께, 실제 서빙 핫패스에서 측정하세요.

간단한 마이크로 벤치 예시:

import time
import torch

def bench(model, iters=200, warmup=50):
    x = torch.randn(1, 3, 224, 224)
    model.eval()

    with torch.inference_mode():
        for _ in range(warmup):
            _ = model(x)

        t0 = time.perf_counter()
        for _ in range(iters):
            _ = model(x)
        t1 = time.perf_counter()

    return (t1 - t0) / iters

# fp32_time = bench(model_fp32)
# int8_time = bench(model_int8)
# print(fp32_time, int8_time)

주의할 점:

  • CPU 스레드 수(torch.set_num_threads)에 따라 결과가 크게 달라짐
  • 같은 머신에서 비교해야 함
  • 배치 크기와 입력 크기를 실제 트래픽에 맞춰야 함

흔한 함정: 정확도는 맞는데 운영에서 깨지는 케이스

1) 캘리브레이션/학습과 서빙 전처리가 다름

가장 흔합니다. 정규화 상수, 리사이즈 방식, 컬러 채널 순서가 다르면 activation 분포가 바뀌어 INT8에서 특히 치명적입니다.

2) 모델을 train() 상태로 서빙

옵저버/BN/드롭아웃 등으로 인해 결과가 흔들립니다. 서빙 직전에는 항상 eval()inference_mode()를 강제하세요.

3) 연산 미지원으로 부분 FP32 fallback

겉으로는 돌아가는데 성능이 안 나오거나, 경계에서 오차가 커질 수 있습니다. 변환 후 모델 그래프를 확인하고, 어떤 모듈이 quantized로 바뀌었는지 점검하세요.

4) 지연시간은 줄었는데 재시도 폭주로 비용 증가

p99가 줄면 타임아웃을 공격적으로 줄이고 싶어지지만, 분산 환경에서는 작은 흔들림이 재시도 폭주로 이어질 수 있습니다. 모델 최적화 후에는 반드시 데드라인/리트라이 정책을 재검토하세요. 이 주제는 gRPC MSA에서 데드라인·리트라이 폭주 막는 법에서 더 깊게 다룹니다.

권장 워크플로우 요약

  1. 목표 정의: 정확도 하락 허용 범위, p50/p99 지연시간, CPU/메모리 예산
  2. Dynamic quantization으로 빠른 PoC
  3. Static PTQ:
    • fusion
    • get_default_qconfig에서 시작
    • 캘리브레이션 데이터 품질/다양성 확보
  4. 부족하면 QAT:
    • 짧은 파인튜닝
    • 민감 레이어 부분 양자화 고려
  5. 엔드투엔드 벤치마크 + 운영 정책(타임아웃/리트라이) 재정렬

마무리

INT8 양자화는 “적용하면 끝”이 아니라, 분포(캘리브레이션)와 민감도(레이어별 오차) 를 다루는 작업입니다. PTQ는 캘리브레이션과 옵저버 설정이 정확도의 대부분을 좌우하고, QAT는 비용이 들지만 정확도를 지키는 가장 확실한 방법입니다.

다음 액션으로는, 현재 모델에서 PTQ를 적용한 뒤 레이어별 민감도를 측정해(예: 특정 블록만 FP32로 남겨서 비교) “어디서 깨지는지”를 먼저 찾는 것을 권합니다. 그 결과를 바탕으로 PTQ 튜닝으로 끝낼지, QAT 파인튜닝으로 갈지 판단하면 시행착오를 크게 줄일 수 있습니다.