Published on

PyTorch 2.1 PT2E로 INT8 양자화 모델 만들기

Authors

서빙 비용을 줄이거나(특히 CPU 인퍼런스), 엣지 디바이스에서 레이턴시를 낮추려면 INT8 양자화는 가장 효과적인 카드 중 하나입니다. PyTorch 2.1부터는 기존 FX Graph Mode 양자화 흐름과 더불어, 내보내기(export) 기반의 PT2E(PyTorch 2 Export) 양자화 파이프라인이 본격적으로 자리 잡았습니다.

이 글에서는 PyTorch 2.1 PT2E로 INT8 양자화 모델을 만드는 전체 흐름을, “돌아가는 코드” 중심으로 정리합니다. 특히 다음을 목표로 합니다.

  • 학습된 FP32 모델을 post-training static quantization(PTQ) 로 INT8로 변환
  • 캘리브레이션 데이터로 스케일/제로포인트를 수집
  • 백엔드(x86, qnnpack)에 맞는 qconfig 선택
  • 변환 후 정확도/성능을 측정하고, 흔한 실패 포인트를 피하기

운영 환경에서 성능 이슈를 추적하는 방식은 언어가 달라도 유사합니다. 예를 들어 분산 트레이싱으로 병목을 확인하는 관점은 OpenTelemetry로 MSA 분산 트랜잭션 추적 실전 글의 접근과도 통합니다. 양자화도 “변환이 끝”이 아니라 “측정과 검증”이 핵심입니다.

PT2E 양자화가 뭐가 다른가

기존 PyTorch 양자화는 크게 eager mode, FX graph mode가 있었고, PT2E는 torch.export 기반으로 그래프를 안정적으로 캡처한 뒤, 그 그래프에 양자화 준비/변환을 적용하는 흐름입니다.

PT2E의 장점은 대략 다음과 같습니다.

  • export 기반이라 그래프가 더 명확하고 재현성이 좋음
  • 런타임/백엔드와 결합되는 경로(특히 edge/모바일/서버 CPU)에 유리
  • 앞으로의 PyTorch 2.x 최적화 스택과 정렬되는 방향

다만, 모델이 torch.export 로 캡처 가능해야 합니다(동적 제어 흐름, 일부 Python side effect 등은 제약이 될 수 있음).

준비물: 버전, 백엔드, 예제 모델

설치

아래는 CPU 중심 예시입니다.

pip install torch==2.1.* torchvision --index-url https://download.pytorch.org/whl/cpu

환경이 꼬여서 pip install 은 성공했는데 실행 시 ModuleNotFoundError 가 나는 경우가 의외로 자주 있습니다. 그때는 인터프리터/가상환경 혼용을 먼저 의심하세요. 체크리스트는 이 글이 매우 실전적입니다: pip install은 성공인데 실행하면 ModuleNotFoundError가 뜰 때...

백엔드 선택 개념

  • 서버/데스크톱 CPU(대부분 x86): x86 엔진을 주로 사용
  • ARM 모바일/라즈베리파이: qnnpack 계열이 흔함

PyTorch에서 엔진은 torch.backends.quantized.engine 로 선택합니다.

import torch

# x86 서버라면 보통 이 선택
torch.backends.quantized.engine = "x86"
print(torch.backends.quantized.engine)

전체 파이프라인 개요

PT2E INT8 PTQ의 큰 흐름은 다음과 같습니다.

  1. FP32 모델 준비 + eval()
  2. 예시 입력(example input) 준비
  3. torch.export.export_for_training 또는 torch.export.export 로 그래프 캡처
  4. Quantizer 설정(XNNPACKQuantizer 또는 X86InductorQuantizer 등)
  5. prepare_pt2e 로 관측자/가짜양자화 노드 삽입
  6. 캘리브레이션 데이터로 prepared 모델을 몇 번 실행
  7. convert_pt2e 로 INT8 모델 생성
  8. 정확도/성능 측정

아래 예제는 “정적 양자화(activation 포함 INT8)”를 목표로 합니다.

실전 예제: 간단한 CNN을 PT2E로 INT8 변환

아래 코드는 개념을 보여주기 위한 최소 예시입니다. 실제로는 ResNet, ViT 등 더 복잡한 모델에도 같은 패턴을 적용합니다(단, export 가능성/연산 지원 여부는 확인 필요).

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1) FP32 모델 정의
class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.fc = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model_fp32 = SmallCNN().eval()

# 2) 예시 입력
example_inputs = (torch.randn(1, 3, 32, 32),)

1) export로 그래프 캡처

PT2E에서는 torch.export 로 그래프를 캡처합니다.

import torch

# PyTorch 버전에 따라 export API가 조금 다를 수 있습니다.
# 2.1 계열에서는 아래 형태를 많이 사용합니다.
exported = torch.export.export(model_fp32, example_inputs)

2) Quantizer 설정 + prepare

PyTorch 2.1의 PT2E 양자화는 torch.ao.quantization.quantize_pt2e 아래 유틸을 사용합니다.

아래 코드는 x86 INT8을 가정합니다.

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

# 3) Quantizer 생성
quantizer = X86InductorQuantizer()
qconfig = get_default_x86_inductor_quantization_config()
quantizer.set_global(qconfig)

# 4) prepare: 관측자/통계 수집 준비
prepared = prepare_pt2e(exported, quantizer)

여기서 prepared 는 아직 INT8이 아닙니다. 캘리브레이션을 통해 activation range를 수집해야 합니다.

3) 캘리브레이션 루프

캘리브레이션은 “대표 데이터 몇 백~몇 천 샘플”을 흘려보내는 과정입니다. 학습은 하지 않고, forward만 반복합니다.

def calibrate(model, data_loader, num_batches=10):
    model.eval()
    with torch.no_grad():
        for i, (x,) in enumerate(data_loader):
            model(x)
            if i + 1 >= num_batches:
                break

# 더미 캘리브레이션 데이터
calib_loader = [(torch.randn(16, 3, 32, 32),) for _ in range(20)]

calibrate(prepared, calib_loader, num_batches=20)

캘리브레이션 데이터 품질이 INT8 정확도를 좌우합니다.

  • 입력 분포가 실제 트래픽과 유사해야 함
  • 너무 적은 배치로 끝내면 activation range가 왜곡될 수 있음
  • outlier가 많으면 스케일이 커져서 정밀도가 떨어질 수 있음

4) convert로 INT8 모델 생성

quantized_model = convert_pt2e(prepared)
quantized_model.eval()

이제 quantized_model 은 INT8 연산을 포함한 그래프가 됩니다(백엔드/연산 지원 여부에 따라 일부는 FP32로 남을 수 있음).

정확도 검증: FP32 vs INT8 출력 비교

정확도 평가는 태스크에 따라 다르지만, 최소한 “출력 분포가 심하게 깨졌는지”를 빠르게 확인할 수 있습니다.

import torch

def compare_outputs(fp32_model, int8_model, x):
    fp32_model.eval()
    int8_model.eval()
    with torch.no_grad():
        y_fp32 = fp32_model(x)
        y_int8 = int8_model(x)

    # 간단 지표: 평균 절대 오차
    mae = (y_fp32 - y_int8).abs().mean().item()
    # 상대적으로 스케일 영향을 줄이려면 cosine similarity도 자주 봅니다.
    cos = torch.nn.functional.cosine_similarity(y_fp32, y_int8, dim=1).mean().item()
    return mae, cos

x = torch.randn(8, 3, 32, 32)
mae, cos = compare_outputs(model_fp32, quantized_model, x)
print("MAE:", mae, "Cosine:", cos)

실서비스에서는 이 수준 비교가 아니라, 원래의 validation set으로 top-1, mAP, WER 같은 태스크 지표를 재측정해야 합니다.

성능 측정: CPU 레이턴시 벤치마크

양자화의 목적은 보통 성능입니다. 아래는 매우 단순한 벤치마크입니다.

import time
import torch

def bench(model, x, iters=200, warmup=50):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(x)

        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        t1 = time.perf_counter()

    return (t1 - t0) / iters

x = torch.randn(1, 3, 32, 32)

t_fp32 = bench(model_fp32, x)
t_int8 = bench(quantized_model, x)
print("fp32 sec/iter:", t_fp32)
print("int8 sec/iter:", t_int8)
print("speedup:", t_fp32 / t_int8)

주의할 점:

  • CPU 스레드 수(OMP_NUM_THREADS, torch.set_num_threads)에 따라 결과가 크게 바뀜
  • 작은 모델은 오버헤드가 더 커서 speedup이 기대보다 낮을 수 있음
  • 일부 연산이 INT8로 내려가지 못하면 병목이 남음

흔한 함정과 디버깅 포인트

1) export가 실패하는 경우

torch.export 는 Python 제어 흐름, 데이터 의존 분기, 일부 동적 shape 처리에서 막힐 수 있습니다. 해결 방향은 다음입니다.

  • 모델 forward에서 Python side effect 제거(리스트 append, dict mutate 등)
  • 입력 shape를 고정하거나, 지원되는 dynamic shape 옵션을 사용
  • 문제 레이어를 분리해 서브모듈 단위로 export 시도

2) INT8로 안 내려가는 연산이 많은 경우

변환이 되었는데도 성능이 안 나오면, 실제로는 핵심 연산이 FP32로 남아 있을 수 있습니다.

  • 지원되는 op 조합인지 확인(Conv+ReLU, Linear 등은 상대적으로 잘 됨)
  • normalization, 일부 activation, reshape 패턴 등에서 끊길 수 있음
  • 모델을 약간 재구성해 fusion-friendly하게 만들면 개선됨

3) 캘리브레이션 데이터가 부정확한 경우

INT8 PTQ에서 가장 흔한 정확도 하락 원인입니다.

  • 트래픽 대표성이 없는 랜덤/편향 데이터로 캘리브레이션
  • 배치 수가 너무 적음
  • 입력 전처리(정규화, 리사이즈)가 실제와 다름

캘리브레이션 파이프라인을 운영 환경과 동일하게 맞추는 것이 중요합니다. 운영에서 병목이나 오류를 추적하는 습관은 다른 영역에서도 동일하게 적용됩니다. 예를 들어 장애 상황에서 재시도/백오프를 체계적으로 설계하는 접근은 Claude API 529 Overloaded 재시도·백오프 설계 글이 참고가 됩니다.

4) 정확도는 괜찮은데 성능이 안 나오는 경우

가능성이 큰 원인들:

  • 스레딩/NUMA/CPU governor 영향
  • 작은 배치에서 양자화 오버헤드가 상대적으로 큼
  • 모델이 memory-bound라 INT8 이득이 제한적

이때는 단순 평균 레이턴시뿐 아니라, p50/p95, 스루풋(QPS), CPU 사용률을 함께 봐야 합니다.

실무 팁: 언제 PTQ 대신 QAT를 고려하나

PTQ(INT8)로도 충분한 경우가 많지만, 아래 상황이면 QAT(Quantization-Aware Training)를 고려합니다.

  • 정확도 하락이 임계치 이상
  • activation 분포가 매우 예민한 모델(예: 일부 transformer 변형)
  • outlier가 많아 PTQ 스케일링이 불리한 경우

다만 QAT는 학습 비용이 추가되므로, 먼저 PTQ로 빠르게 실험하고 “안 되면 QAT” 순서가 현실적입니다.

체크리스트: INT8 PT2E 적용 전후로 꼭 확인할 것

  • model.eval() 상태에서 export/prepare/convert 수행했는가
  • torch.backends.quantized.engine 이 타깃 CPU에 맞는가
  • 캘리브레이션 데이터가 실제 입력 분포를 대표하는가
  • 변환 후 태스크 지표(정확도)와 레이턴시를 모두 측정했는가
  • 핵심 연산이 실제로 INT8로 내려갔는지(부분 FP32 fallback 여부) 확인했는가

마무리

PyTorch 2.1의 PT2E는 “export 기반 그래프 + 양자화”라는 흐름으로, 앞으로의 PyTorch 최적화 스택과 결이 맞습니다. INT8 PTQ는 구현 난이도 대비 효과가 큰 편이지만, 캘리브레이션 품질연산 지원 범위가 결과를 좌우합니다.

다음 단계로는 (1) 실제 모델(ResNet, BERT 계열)에 적용, (2) 캘리브레이션 샘플링 전략 개선, (3) 성능 프로파일링으로 FP32 fallback 구간 제거, (4) 필요 시 QAT 전환까지 이어가면 됩니다.