Published on

PyTorch 2.x PTQ로 INT8 양자화, 정확도 지키기

Authors

서버/엣지 추론 비용을 줄이기 위해 INT8 양자화는 거의 필수 옵션이 됐습니다. 다만 PTQ(Post-Training Quantization)는 학습을 다시 하지 않는 대신, 캘리브레이션과 설정이 부정확하면 정확도가 크게 흔들릴 수 있습니다. 이 글은 PyTorch 2.x 환경에서 torch.ao.quantization 기반 PTQ로 INT8을 적용하면서 정확도를 최대한 유지하는 방법을 코드와 함께 정리합니다.

또한 로컬 LLM이나 대형 모델 최적화의 큰 흐름(메모리, 속도, 배치 전략)과 함께 보면 전체 시스템 최적화 관점이 잡힙니다. 필요하면 Transformers 로컬 LLM OOM·속도 최적화 가이드도 같이 참고하세요.

PTQ INT8에서 정확도가 깨지는 대표 원인

PTQ에서 정확도 하락은 대개 아래 케이스로 귀결됩니다.

  1. 캘리브레이션 데이터 분포 불일치: 실제 트래픽과 다른 입력으로 통계(min/max, 히스토그램)를 잡으면 스케일이 틀어집니다.
  2. 관측자(observer) 선택 오류: MinMaxObserver는 outlier에 취약하고, 히스토그램 기반은 데이터가 부족하면 불안정할 수 있습니다.
  3. per-tensor vs per-channel 선택 미스: 특히 Conv/Linear weight는 per-channel이 정확도에 유리한 경우가 많습니다.
  4. 활성화(activation) 동적 범위가 큰 구간: attention, residual add, layernorm 주변은 INT8 손실이 커지기 쉽습니다.
  5. 양자화 대상 레이어를 무리하게 확대: “전부 INT8”이 항상 정답이 아닙니다. 일부 민감 레이어는 FP16/FP32로 남기는 편이 전체 정확도에 이득입니다.

PyTorch 2.x PTQ 기본 파이프라인

PyTorch 2.x에서 eager mode PTQ는 보통 다음 순서로 진행합니다.

  1. 모델을 eval()로 전환
  2. qconfig 설정(관측자, per-channel 등)
  3. prepare로 관측자 삽입
  4. 캘리브레이션 데이터로 forward 수행(통계 수집)
  5. convert로 INT8 연산으로 변환

아래 예시는 FX Graph Mode가 아니라 eager mode를 기준으로 하되, 핵심 개념(관측자/캘리브레이션/제외 전략)은 동일하게 적용됩니다.

예제: ResNet 계열을 PTQ INT8로 변환

import torch
import torchvision

from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization import prepare, convert

# 1) 모델 로드 및 eval
model = torchvision.models.resnet18(weights=None)
model.eval()

# 2) 백엔드 선택: x86 서버면 보통 fbgemm, ARM이면 qnnpack
#    PyTorch는 엔진 설정에 따라 커널이 달라집니다.
torch.backends.quantized.engine = "fbgemm"

# 3) qconfig 설정
model.qconfig = get_default_qconfig("fbgemm")

# 4) 관측자 삽입
prepared = prepare(model)

# 5) 캘리브레이션: 대표 샘플로 forward만 수행
#    실제 서비스 입력 분포를 최대한 반영하는 게 중요합니다.
with torch.inference_mode():
    for _ in range(32):
        x = torch.randn(1, 3, 224, 224)
        prepared(x)

# 6) 변환
quantized_model = convert(prepared)

# 추론
with torch.inference_mode():
    y = quantized_model(torch.randn(1, 3, 224, 224))
    print(y.shape)

위 코드는 “돌아가는 최소 예시”에 가깝고, 정확도 유지 관점에서는 추가 설정이 필요합니다. 이제부터가 핵심입니다.

정확도 유지 팁 1: 캘리브레이션 데이터는 “대표성”이 전부

PTQ는 캘리브레이션에서 수집한 통계를 기준으로 스케일과 제로포인트를 정합니다. 따라서 캘리브레이션 데이터는 아래 조건을 만족해야 합니다.

  • 실제 서비스 입력과 전처리/정규화가 동일
  • 클래스 분포, 길이 분포(텍스트), 해상도/조명(이미지) 등 분포가 유사
  • outlier가 있는 서비스라면 outlier도 포함(다만 너무 과하면 역효과)

실무 팁:

  • “랜덤 32개” 같은 소량은 위험합니다. 최소 수백~수천 샘플을 권장합니다(모델/도메인에 따라 다름).
  • 배치 크기도 실제 추론과 비슷하게 맞추면 activation 통계가 더 현실적입니다.

정확도 유지 팁 2: 관측자(observer)와 스킴(scheme) 선택

PyTorch 기본 get_default_qconfig는 무난하지만, 모델 특성에 맞게 조정하면 정확도를 더 지킬 수 있습니다.

per-channel weight 양자화는 기본으로 고려

Conv/Linear weight는 per-channel이 per-tensor보다 손실이 적은 경우가 많습니다.

import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

activation_observer = HistogramObserver.with_args(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
)

weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)

qconfig = QConfig(activation=activation_observer, weight=weight_observer)
model.qconfig = qconfig
  • activation은 히스토그램 기반이 outlier에 덜 민감한 편입니다.
  • weight는 대칭(symmetric) + per-channel 조합이 흔히 좋은 출발점입니다.

주의: 히스토그램 관측자는 캘리브레이션 데이터가 너무 적으면 오히려 불안정할 수 있습니다. 데이터가 적다면 MinMaxObserver가 더 나을 때도 있습니다.

정확도 유지 팁 3: “민감 레이어”는 과감히 제외(Selective Quantization)

PTQ에서 가장 흔한 실전 해법은 일부 레이어를 FP로 남기는 것입니다. 특히 아래 구간은 민감한 경우가 많습니다.

  • 출력 logits 직전 Linear
  • residual branch의 add 주변
  • LayerNorm, Softmax, attention score 계산부

Eager mode에서는 모듈 단위로 qconfig = None 처리로 제외할 수 있습니다.

import torch.nn as nn

def disable_quant_for_layernorm(model: nn.Module):
    for name, m in model.named_modules():
        if isinstance(m, nn.LayerNorm):
            m.qconfig = None

model.qconfig = qconfig

disable_quant_for_layernorm(model)
prepared = prepare(model)

추가로 “마지막 분류기(FC)만 FP로 유지”도 자주 쓰는 전략입니다.

# 예: ResNet의 마지막 fc 제외
model.fc.qconfig = None

정확도 하락이 큰데 latency 이득이 충분하면, 전체 레이어를 INT8로 미는 것보다 이런 절충이 더 좋은 결과를 줍니다.

정확도 유지 팁 4: Fusion(연산 결합)으로 양자화 손실과 오버헤드 줄이기

Conv-BN-ReLU 같은 패턴은 fuse하면 양자화에 유리하고 성능도 좋아집니다. PyTorch는 일부 모델에 대해 fuse 유틸을 제공합니다.

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = SmallCNN().eval()

# fuse: ["conv", "bn", "relu"]
model_fused = fuse_modules(model, [["conv", "bn", "relu"]], inplace=False)

fusion은 단순 성능 최적화가 아니라, BN이 Conv에 흡수되면서 스케일이 안정화되어 양자화 오차가 줄어드는 효과도 기대할 수 있습니다.

정확도 유지 팁 5: 캘리브레이션 시 “실제 전처리”를 그대로 재현

이미지 모델은 정규화(mean/std), 리사이즈, 센터 크롭 등 전처리가 activation 범위를 크게 바꿉니다. 텍스트 모델은 토크나이저 버전/패딩/트렁케이션 전략이 통계에 영향을 줍니다.

캘리브레이션 파이프라인은 반드시 추론 파이프라인과 동일하게 유지하세요.

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 캘리브레이션에서도 동일 preprocess 적용

정확도 유지 팁 6: 평가 지표를 “정확도 하나”로 단순화하지 않기

PTQ 적용 후에는 단순 top-1 accuracy 외에도 다음을 같이 봐야 합니다.

  • 클래스별 precision/recall 변화(특정 클래스만 무너질 수 있음)
  • confidence calibration(softmax 확률 분포가 바뀌는지)
  • 임계값 기반 태스크라면 threshold 재튜닝 필요 여부

특히 운영 환경에서 “정확도 하락”은 사용자 체감과 다르게 나타납니다. 예를 들어 상위 1개 라벨은 유지되지만 확률이 낮아져 후속 로직(필터링/거절)이 바뀔 수 있습니다.

정확도 유지 팁 7: 성능 병목은 양자화만으로 안 끝난다

INT8로 줄인 메모리/연산 이득이 실제 latency로 직결되지 않는 경우가 많습니다.

  • 전처리/후처리 CPU 비용
  • 배치/스레딩 설정
  • 모델 로딩 및 warm-up
  • I/O 및 RPC 오버헤드

로컬 LLM/대형 추론에서 OOM과 속도 병목을 함께 다루는 관점은 Transformers 로컬 LLM OOM·속도 최적화 가이드에 정리해두었습니다. “INT8 자체”보다 “전체 파이프라인”에서 시간을 먹는 구간을 먼저 확인하는 게 실전에서 더 큰 개선으로 이어집니다.

실전 체크리스트

아래 순서대로 점검하면 PTQ INT8의 실패 확률이 크게 줄어듭니다.

  1. 캘리브레이션 데이터가 실제 입력 분포를 대표하는가
  2. 전처리/토크나이저/패딩 전략이 운영과 동일한가
  3. weight는 per-channel을 적용했는가
  4. activation observer를 데이터 양/아웃라이어 특성에 맞게 골랐는가
  5. 민감 레이어(LayerNorm, 마지막 FC 등)를 제외해봤는가
  6. fusion 가능한 패턴을 fuse했는가
  7. 정확도 외에 confidence/클래스별 지표도 비교했는가

마무리

PyTorch 2.x PTQ INT8은 “버튼 한 번”으로 끝나는 기능처럼 보이지만, 실제로는 캘리브레이션과 관측자 선택, 그리고 선택적 양자화 전략이 정확도를 좌우합니다. 특히 분포 대표성을 갖춘 캘리브레이션 데이터와 per-channel weight, 민감 레이어 제외만 제대로 적용해도 체감 성능을 유지한 채 비용을 크게 줄일 수 있습니다.

다음 단계로는 모델 특성에 따라 FX Graph Mode PTQ나(그래프 기반으로 더 정교한 제어) QAT(학습 기반 양자화)까지 검토하는 것이 좋습니다. PTQ로 목표 정확도를 만족하지 못한다면, 그때가 QAT로 넘어갈 타이밍입니다.