Published on

PyTorch 2.1 PTQ로 INT8 경량화 - 정확도 유지

Authors

서빙 환경에서 모델을 더 빠르고 싸게 돌리는 가장 현실적인 방법 중 하나가 INT8 양자화입니다. 특히 학습을 다시 하지 않고도 적용 가능한 PTQ(Post-Training Quantization)는 운영 부담이 낮아, CPU 추론이나 엣지 배포에서 자주 선택됩니다.

다만 INT8은 단순히 dtype만 바꾸는 문제가 아닙니다. 어디를 INT8로 내리고(가중치/활성값), 어떤 관측 통계로 스케일을 잡고(캘리브레이션), 어떤 백엔드로 실행할지(FBGEMM/QNNPACK/XNNPACK), 그리고 모델 구조(Conv, Linear, Attention)에 따라 정확도 손실 양상이 크게 달라집니다.

이 글에서는 PyTorch 2.1 기준으로 PTQ를 적용해 정확도 손실을 최소화하는 흐름을 정리합니다. 예시는 TorchVision 분류 모델로 설명하지만, 원리는 대부분의 CNN/MLP/간단한 Transformer 블록에도 동일하게 적용됩니다.

관련해서 로컬 LLM을 돌릴 때 메모리 압박을 줄이는 방법은 별도 글인 Transformers 로컬 LLM CUDA OOM 줄이는 9가지도 함께 참고하면 좋습니다. 양자화는 OOM을 줄이는 강력한 카드 중 하나지만, 배치/캐시/커널 선택 같은 운영 튜닝과 같이 봐야 체감 효과가 큽니다.

PTQ와 INT8의 핵심 개념

PTQ가 하는 일

PTQ는 학습된 FP32(또는 FP16/BF16) 모델을 가져와서 다음을 수행합니다.

  • 가중치(Weight)와 활성값(Activation)을 INT8 범위로 표현하기 위한 스케일과 제로포인트를 결정
  • 일부 연산을 INT8 커널로 치환(예: Linear, Conv2d)
  • 양자화/역양자화 노드를 그래프에 삽입하거나, 아예 양자화된 모듈로 교체

정확도 손실은 주로 활성값 분포를 잘못 추정하거나(캘리브레이션 부족), 이상치(outlier) 때문에 스케일이 과도하게 커져 유효 비트가 줄어드는 상황에서 발생합니다.

INT8 양자화에서 중요한 선택지

  • Per-tensor vs per-channel(보통 가중치는 per-channel이 정확도가 더 좋음)
  • Symmetric vs affine(제로포인트 포함)
  • 정적(static) vs 동적(dynamic)
    • 동적 양자화는 주로 Linear 계열에서 입력 활성값을 런타임에 스케일링해 캘리브레이션 부담이 적음
    • 정적 양자화는 캘리브레이션이 필요하지만 Conv 등 더 넓은 커버리지와 더 좋은 성능을 기대할 수 있음

PyTorch 2.1에서도 많은 경우 torch.ao.quantization 계열 API를 사용하며, 백엔드로는 x86 서버 CPU에서는 보통 fbgemm이 선택됩니다.

어떤 PTQ 경로를 선택할까

실무에서 가장 흔한 선택지는 다음 2가지입니다.

  1. 동적 양자화(dynamic quantization)
  • 장점: 캘리브레이션 데이터가 없어도 적용 가능, 적용이 간단
  • 단점: Conv 중심 모델에는 효과가 제한적, 성능 이득이 정적 대비 낮을 수 있음
  1. 정적 양자화(static quantization)
  • 장점: Conv/Linear 모두 커버, CPU에서 성능 이득이 더 크기 쉬움
  • 단점: 캘리브레이션 필수, 준비가 부족하면 정확도 손실이 커질 수 있음

이미지 분류 CNN이라면 정적 양자화 쪽이 정석이고, 텍스트 분류나 MLP 위주라면 동적 양자화로도 충분한 경우가 많습니다.

준비: 재현 가능한 평가 루프 만들기

정확도 유지가 목표라면, 양자화 전후를 같은 조건에서 비교할 수 있는 평가 루프가 먼저 필요합니다.

  • model.eval() 고정
  • seed 고정
  • 전처리 동일
  • 대표성 있는 검증셋(또는 최소한의 샘플셋) 확보

아래 코드는 TorchVision의 resnet18을 예시로, 간단한 top-1 정확도 평가 루프의 뼈대를 보여줍니다.

import time
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(0)

def build_val_loader(data_dir: str, batch_size: int = 64, num_workers: int = 4):
    tfm = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    ds = torchvision.datasets.ImageFolder(data_dir, transform=tfm)
    return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

@torch.inference_mode()
def eval_top1(model, loader, device="cpu"):
    model.eval().to(device)
    correct = 0
    total = 0
    t0 = time.time()
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    dt = time.time() - t0
    return correct / max(total, 1), dt

이제 FP32 baseline을 측정하고, 양자화 모델의 정확도와 속도를 같은 루프에서 비교하면 됩니다.

방법 1: 동적 양자화로 빠르게 INT8 적용

동적 양자화는 Linear 계열에 특히 유용합니다. 예를 들어 간단한 MLP, 일부 RNN, 또는 Transformer의 FFN(선형층) 위주에서 효과가 납니다.

import torch
from torch.ao.quantization import quantize_dynamic

model_fp32 = torchvision.models.resnet18(weights=None)
model_fp32.eval()

# 동적 양자화는 주로 Linear에 적용 (Conv는 보통 대상 아님)
model_int8_dyn = quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

print(model_int8_dyn)

정확도 손실이 적고 적용이 쉬운 대신, ResNet처럼 Conv 비중이 큰 모델에서는 속도 이득이 제한적일 수 있습니다. 반대로 텍스트/추천/탭уляр(tabular) 계열의 선형 모델에서는 꽤 좋은 선택입니다.

방법 2: 정적 PTQ로 Conv까지 INT8화

정적 양자화는 크게 다음 순서를 따릅니다.

  1. 백엔드 선택: x86 서버라면 fbgemm
  2. qconfig 설정
  3. fuse_modules로 Conv+BN+ReLU 같은 패턴을 fuse
  4. prepare로 observer 삽입
  5. 캘리브레이션 데이터로 몇 배치 추론해 통계 수집
  6. convert로 양자화 모듈로 변환

1) 백엔드와 qconfig

import torch
import torchvision
from torch.ao.quantization import get_default_qconfig

torch.backends.quantized.engine = "fbgemm"  # x86 서버 CPU 기준

model_fp32 = torchvision.models.resnet18(weights=None)
model_fp32.eval()

qconfig = get_default_qconfig("fbgemm")
model_fp32.qconfig = qconfig

2) Fusion: 정확도와 성능의 출발점

Conv+BN은 추론 시 하나로 합칠 수 있고, ReLU까지 묶으면 양자화 친화적인 패턴이 됩니다. Fusion을 하지 않으면 성능이 떨어지거나, 양자화 경로가 꼬여 정확도 손실이 커질 수 있습니다.

ResNet은 블록 구조가 복잡하므로, 실무에서는 TorchVision의 quantization-ready 모델을 쓰거나, FX Graph Mode를 활용하는 편이 안전합니다. 여기서는 개념 전달을 위해 FX Graph Mode 기반의 흐름을 보여줍니다.

3) FX Graph Mode PTQ 예시

import torch
import torchvision
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

torch.backends.quantized.engine = "fbgemm"

model_fp32 = torchvision.models.resnet18(weights=None)
model_fp32.eval()

qconfig_mapping = get_default_qconfig_mapping("fbgemm")

# 예시 입력(형상 추적용)
example_inputs = (torch.randn(1, 3, 224, 224),)

prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs)

# 캘리브레이션: 대표 데이터 몇 배치로 observer 통계 수집
@torch.inference_mode()
def calibrate(model, loader, num_batches: int = 20):
    model.eval()
    for i, (x, _) in enumerate(loader):
        model(x)
        if i + 1 >= num_batches:
            break

# val_loader는 앞에서 만든 DataLoader를 사용
# calibrate(prepared, val_loader, num_batches=20)

model_int8 = convert_fx(prepared)

위 코드에서 가장 중요한 포인트는 캘리브레이션입니다. num_batches를 너무 적게 잡으면 활성값 분포를 제대로 못 보고 스케일이 불안정해져 정확도 손실이 커집니다. 반대로 너무 많이 잡는다고 무조건 좋아지진 않지만, 일반적으로 데이터 다양성이 확보될수록 안전합니다.

정확도 손실을 줄이는 캘리브레이션 전략

PTQ에서 정확도는 캘리브레이션 품질에 크게 좌우됩니다.

1) 캘리브레이션 데이터는 “대표성”이 전부

  • 운영 입력 분포와 유사한 데이터 사용
  • 클래스 균형이 깨져도 되지만, 조명/해상도/길이/노이즈 등 입력 다양성은 확보
  • 전처리 파이프라인을 운영과 동일하게 유지

예를 들어 이미지 모델이라면 리사이즈/크롭/정규화가 조금만 달라도 activation 분포가 달라져 스케일이 흔들릴 수 있습니다.

2) 배치 수보다 “커버리지”

  • 20배치가 200배치보다 나을 수도, 반대일 수도 있습니다.
  • 핵심은 outlier를 포함한 분포를 얼마나 잘 포함하느냐입니다.

3) 관측기(observer) 선택과 outlier 대응

기본 MinMaxObserver는 outlier에 취약합니다. outlier가 큰 경우 스케일이 커져 대부분 값이 INT8의 좁은 구간에 몰리며 정보가 손실됩니다.

가능한 대응:

  • 히스토그램 기반 observer(예: HistogramObserver)로 KL 기반 스케일 추정
  • 채널별(per-channel) 가중치 양자화 사용
  • 특정 레이어는 FP32로 남기는 혼합 정밀도(Selective quantization)

PyTorch의 기본 qconfig는 백엔드별로 합리적인 기본값을 제공하지만, 모델/데이터에 따라 observer를 바꾸는 것만으로도 정확도 손실이 크게 줄어드는 케이스가 있습니다.

“정확도 유지”를 위한 실전 체크리스트

1) 양자화 제외 레이어를 과감히 지정

모델에서 민감한 구간(첫 Conv, 마지막 FC, LayerNorm/Softmax 주변 등)은 양자화에서 제외하면 정확도를 지키기 쉬워집니다. 특히 Transformer 계열에서 LayerNorm은 INT8로 내리면 손실이 커지는 경우가 많습니다.

FX Graph Mode에서는 모듈 이름 패턴으로 qconfig를 다르게 매핑할 수 있습니다.

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

act_obs = MinMaxObserver.with_args(dtype=torch.quint8)
wt_obs = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)

custom_mapping = (
    QConfigMapping()
    .set_global(torch.ao.quantization.QConfig(activation=act_obs, weight=wt_obs))
    # 예: 마지막 분류기는 FP32로 유지
    .set_module_name("fc", None)
)

주의할 점은 모듈 이름이 실제 모델과 일치해야 한다는 것입니다. print(model)로 구조를 확인한 뒤 적용하세요.

2) 정확도만 보지 말고 “실제 지연시간”을 측정

INT8로 바꿨는데도 느려지는 경우가 있습니다.

  • 작은 배치에서 오히려 overhead가 커짐
  • 스레드 설정이 비효율적
  • 일부 연산이 양자화 커널로 치환되지 못해 dequant-quant가 반복됨

CPU 추론에서는 아래 설정이 체감에 큰 영향을 줍니다.

import torch

torch.set_num_threads(8)
torch.set_num_interop_threads(1)

환경에 따라 최적값이 다르니 벤치마크로 결정해야 합니다.

3) 양자화가 “모델 메모리”만 줄이는지, “활성 메모리”까지 줄이는지 구분

PTQ INT8은 가중치 저장 공간을 크게 줄이지만, 프레임워크/연산 경로에 따라 활성 메모리는 기대만큼 줄지 않을 수 있습니다. 특히 그래프 중간에 FP32 경로가 남아 있으면 활성 텐서가 FP32로 유지될 수 있습니다.

로컬 추론에서 OOM이나 메모리 압박이 문제라면 양자화 외에도 캐시/배치/입력 길이 관리가 중요하고, 상황별 팁은 Transformers 로컬 LLM CUDA OOM 줄이는 9가지에서 더 확장된 관점으로 다룹니다.

벤치마크: FP32 vs INT8 비교 템플릿

아래는 동일한 평가 루프에서 정확도와 시간을 비교하는 예시입니다.

# val_loader = build_val_loader("/path/to/val")

acc_fp32, t_fp32 = eval_top1(model_fp32, val_loader, device="cpu")
acc_int8, t_int8 = eval_top1(model_int8, val_loader, device="cpu")

print(f"FP32  acc={acc_fp32:.4f}, time={t_fp32:.2f}s")
print(f"INT8  acc={acc_int8:.4f}, time={t_int8:.2f}s")
print(f"Speedup: {t_fp32 / max(t_int8, 1e-9):.2f}x")

정확도 손실이 허용 범위를 넘는다면 다음 순서로 원인을 좁히는 것이 효율적입니다.

  1. 캘리브레이션 배치 수/데이터 다양성 확대
  2. observer 변경(히스토그램 기반 등)
  3. 민감 레이어 제외(첫/마지막, 정규화 계열)
  4. 정적 대신 동적 양자화로 후퇴(또는 혼합)

PyTorch 2.1에서 자주 겪는 함정

1) 캘리브레이션을 train() 상태로 돌리는 실수

Dropout/BatchNorm이 활성화되면 통계가 오염됩니다. 반드시 eval() + inference_mode()로 캘리브레이션을 수행하세요.

2) 전처리 불일치

캘리브레이션 데이터의 전처리가 운영과 다르면 activation 분포가 달라져 정확도 손실이 커질 수 있습니다. 특히 정규화(mean/std) 누락은 치명적입니다.

3) “양자화 됐는데 왜 안 빨라지지” 문제

  • 모델이 실제로 quantized kernel을 타는지 확인 필요
  • 연산 그래프에 quant-dequant가 많이 남아 있으면 오히려 손해

이 경우 torch.profiler로 연산 비중을 보고, 어떤 op가 병목인지 먼저 확인하는 게 정답입니다.

운영 관점: 배포 전 확인할 것

  • 타겟 CPU에서 백엔드가 기대대로 동작하는지(예: fbgemm)
  • 스레드/NUMA 설정
  • 입력 크기/배치 크기에서의 P95 지연시간
  • 모델 저장 포맷(torch.save vs torch.export/torch.jit)과 로딩 시간

서빙이 Cloud Run 같은 환경이라면 콜드스타트/CPU 제한/동시성 설정이 체감 성능에 큰 영향을 줍니다. 네트워크/플랫폼 병목까지 포함해 점검하려면 GCP Cloud Run 504 타임아웃 원인·해결 9가지처럼 인프라 관점의 체크리스트도 같이 보는 편이 안전합니다.

마무리

PyTorch 2.1에서 PTQ로 INT8 경량화를 성공시키는 핵심은 정확도 손실을 줄이는 캘리브레이션민감 레이어를 선별적으로 보호하는 전략입니다. 동적 양자화는 빠른 적용에 적합하고, 정적 양자화(FX Graph Mode)는 더 큰 성능 이득을 노릴 수 있지만 캘리브레이션 품질이 승부를 가릅니다.

실무에서는 다음 목표를 명확히 두고 선택하면 시행착오가 줄어듭니다.

  • “일단 쉽게”: 동적 양자화로 손실/이득 감 잡기
  • “성능 극대화”: 정적 PTQ + 캘리브레이션/observer/제외 레이어 튜닝
  • “정확도 최우선”: 혼합 정밀도(일부 FP32 유지)로 안정화

이 흐름으로 접근하면 INT8의 비용 절감 효과를 누리면서도, 운영에서 문제가 되는 정확도 하락을 상당 부분 통제할 수 있습니다.