Published on

PyTorch 2.0 PT2E+QAT로 INT8 양자화 실전

Authors

PyTorch 2.0 이후 양자화는 기존 FX Graph Mode를 넘어, torch.export를 중심으로 한 PT2E(PyTorch 2 Export) 파이프라인으로 빠르게 이동하고 있습니다. 특히 QAT(Quantization Aware Training) 는 단순 PTQ(Post Training Quantization)보다 정확도 손실을 줄이기 쉬워서, CPU 추론(특히 x86)에서 INT8 성능을 노릴 때 실무에서 자주 선택합니다.

이 글에서는 PT2E + QAT로 INT8 모델을 만드는 전체 흐름을 정리합니다. 핵심은 다음 4단계입니다.

  1. 모델을 torch.export내보낼 수 있는 형태로 정리
  2. prepare_qat_pt2eFakeQuant/Observer 삽입
  3. QAT 학습(또는 미세조정) 수행
  4. convert_pt2e실제 INT8 연산으로 변환 후 검증

추가로, 배포 파이프라인에서 흔히 부딪히는 “왜 INT8로 바뀌지 않았나”, “성능이 안 나온다”, “정확도가 급락한다” 같은 함정도 함께 다룹니다.

관련해서 ONNX 및 TensorRT 기반 INT8로 이어가려면 아래 글도 참고하면 연결이 매끄럽습니다.


PT2E 양자화가 필요한 이유: FX 대비 달라진 점

기존 FX Graph Mode 양자화는 torch.fx 기반으로 그래프를 캡처하고 변환했습니다. 반면 PT2E는 torch.export가 만들어내는 AOT( ahead-of-time ) 그래프를 기반으로 양자화가 진행됩니다.

PT2E 기반의 장점은 다음과 같습니다.

  • torch.compile/AOTAutograd 등 PyTorch 2 계열 최적화와 결합이 쉬움
  • 내보낸 그래프가 더 “고정된 형태”라서 변환 파이프라인을 안정적으로 구성 가능
  • 백엔드(예: x86, fbgemm, qnnpack)별 제약을 더 명시적으로 다루기 쉬움

단, torch.export가 요구하는 제약(데이터 의존 제어 흐름, 일부 동적 shape 등)을 만족해야 하므로, 모델을 export-friendly하게 만드는 작업이 필요할 수 있습니다.


준비: 버전, 백엔드, 대상 연산 확인

권장 환경

  • PyTorch 2.0 이상(가능하면 2.1+에서 PT2E 양자화 안정성이 더 좋아짐)
  • CPU INT8 타깃
    • x86 서버면 보통 fbgemm 경로를 기대
    • ARM 모바일/라즈베리파이면 qnnpack 경로를 기대

아래 코드는 환경 확인용입니다.

import torch

print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Quantized engine:", torch.backends.quantized.engine)

torch.backends.quantized.engine가 기대한 엔진이 아니면 성능/지원 op가 달라질 수 있습니다. 예를 들어 x86에서 fbgemm이 아닌 경우, INT8 최적화가 제한될 수 있습니다.


전체 파이프라인 개요(PT2E + QAT)

PT2E QAT의 큰 흐름은 아래와 같습니다.

  • torch.export.exportExportedProgram 생성
  • Quantizer 설정(예: X86Quantizer) 및 prepare_qat_pt2e
  • QAT 학습
  • convert_pt2e로 INT8 변환
  • 정확도/성능 검증

여기서 가장 중요한 포인트는 “준비(prepare) 이후에는 모델 그래프에 FakeQuant가 들어간다” 는 점입니다. 학습은 FP32처럼 보이지만, 내부적으로는 양자화 효과를 모사하면서 가중치/활성 분포가 INT8에 맞춰집니다.


예제 모델: Conv-BN-ReLU 블록(대표적인 QAT 타깃)

QAT는 Conv/Linear 계열에서 효과가 큰 편입니다. 예제로 간단한 CNN을 준비합니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = SmallCNN().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

실무에서는 ResNet, MobileNet 계열로 확장하되, Conv-BN-ReLU 패턴이 제대로 fusion/패턴 인식되는지가 INT8 품질과 성능에 큰 영향을 줍니다.


1) torch.export로 ExportedProgram 만들기

PT2E는 먼저 모델을 export합니다.

import torch

model = model.eval()
exported = torch.export.export(model, example_inputs)
print(type(exported))

export 단계에서 자주 깨지는 지점

  • 입력 shape이 완전히 고정되어야 하는 경우
  • forward에 Python 제어 흐름이 강하게 섞인 경우
  • 일부 커스텀 op가 export를 지원하지 않는 경우

이 경우에는 모델을 단순화하거나, 입력을 고정하거나, 가능하면 표준 연산으로 치환하는 식으로 해결합니다.


2) Quantizer 설정과 QAT 준비(prepare)

PT2E 양자화는 “어떤 op를 어떤 방식으로 양자화할지”를 Quantizer로 정의합니다. x86 CPU 타깃이면 보통 X86Quantizer를 사용합니다.

아래 코드는 PT2E QAT 준비의 전형적인 형태입니다(버전에 따라 import 경로/이름이 조금씩 다를 수 있습니다).

import torch

from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

prepared = prepare_qat_pt2e(exported, quantizer)

여기서 무엇이 일어나나

  • Conv/Linear 등에 FakeQuant 노드가 삽입
  • 활성(activation)과 가중치(weight)에 대해 관측/클리핑 스케일을 학습 가능한 형태로 구성
  • 이후 학습에서 INT8로 변환될 분포를 “미리 맞춰가는” 효과가 생김

QAT 준비 시 체크 포인트

  • 모든 레이어가 양자화되는 것이 목표가 아니라, 성능 이득이 큰 구간만 양자화되는 것이 일반적
  • BatchNorm이 포함된 패턴은 변환 단계에서 fusion/패턴 매칭이 중요

3) QAT 학습(또는 미세조정) 루프

QAT는 보통 “처음부터 학습”보다는 FP32로 학습된 체크포인트를 가져와 짧게 미세조정하는 방식이 비용 대비 효율이 좋습니다.

아래는 최소 예시입니다.

import torch
import torch.nn as nn

prepared.train()
optimizer = torch.optim.Adam(prepared.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for step in range(200):
    inputs = torch.randn(32, 3, 32, 32)
    targets = torch.randint(0, 10, (32,))

    logits = prepared(inputs)
    loss = criterion(logits, targets)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(step, float(loss))

실무 팁: QAT 안정화

  • 학습률을 FP32 학습 대비 낮게 시작
  • 초반 몇 epoch는 observer 통계를 안정화시키고, 이후 fake quant 강도를 점진적으로 적용하는 전략을 고려
  • 정확도 급락 시, 양자화 범위를 줄이거나(일부 레이어 제외), per-channel weight quant 옵션을 점검

4) INT8로 변환(convert)하고 검증하기

학습이 끝나면 convert_pt2e로 실제 INT8 연산 그래프로 변환합니다.

prepared.eval()
quantized = convert_pt2e(prepared)

이후에는 다음을 확인해야 합니다.

  • 그래프에 quantize/dequantize 패턴이 생겼는지
  • 기대한 op가 INT8 커널로 내려갔는지
  • 정확도/출력이 허용 오차 내인지

간단한 출력 비교

model_fp32 = model.eval()

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    y_fp32 = model_fp32(x)
    y_int8 = quantized(x)

print("fp32 mean:", y_fp32.mean().item())
print("int8 mean:", y_int8.mean().item())
print("max abs diff:", (y_fp32 - y_int8).abs().max().item())

정확도 평가는 반드시 실제 검증 데이터셋으로 해야 합니다. 위 비교는 “변환이 터지지 않았는지” 수준의 스모크 테스트로만 사용하세요.


성능이 안 나오는 흔한 원인 6가지

1) 실제로는 INT8 커널을 못 타는 경우

그래프가 INT8로 변환되었어도, 특정 op 조합이 백엔드에서 지원되지 않으면 내부적으로 FP32로 fallback될 수 있습니다. 이 경우 “INT8인데 빨라지지 않음”이 발생합니다.

대응:

  • 대상 백엔드(x86, qnnpack 등)에서 지원되는 패턴으로 모델을 구성
  • Conv-BN-ReLU처럼 잘 알려진 패턴을 유지
  • 불필요한 view/permute/reshape가 연산 사이에 끼어 패턴 매칭을 깨는지 확인

2) 작은 배치, 작은 모델이라 오버헤드가 더 큰 경우

INT8은 커널 자체는 빠르지만, QDQ(quantize/dequantize) 오버헤드나 메모리 이동 비용 때문에 작은 워크로드에서는 체감이 약할 수 있습니다.

대응:

  • 배치/입력 크기에서 실제 서비스 조건으로 벤치마크
  • 가능하면 연산을 fuse할 수 있는 구조로 설계

3) calibration/observer 통계가 불안정한 경우

QAT라도 데이터 분포가 불안정하거나 학습이 너무 짧으면 스케일이 흔들려 정확도가 떨어집니다.

대응:

  • 대표 데이터로 충분히 미세조정
  • 데이터 전처리(정규화)가 학습/추론에서 완전히 동일한지 확인

4) 레이어 일부만 양자화되어 병목이 남는 경우

가장 무거운 블록이 양자화 대상에서 빠져 있으면 전체 성능 개선이 제한됩니다.

대응:

  • 병목 레이어를 프로파일링으로 찾고, 해당 레이어가 양자화되는지 확인

5) 연산자 교체로 export/quant 패턴이 깨지는 경우

예를 들어 activation을 커스텀으로 바꾸거나, BN을 forward에서 직접 계산하는 형태로 바꾸면 패턴 매칭이 어려워질 수 있습니다.

대응:

  • 표준 nn.ReLU, nn.BatchNorm2d를 유지
  • 가능하면 torch.nn.functional보다 모듈 형태를 유지(패턴 인식에 유리한 경우가 있음)

6) 멀티스레딩/런타임 설정 문제

CPU 추론 성능은 OMP_NUM_THREADS, MKL_NUM_THREADS 같은 런타임 설정에 크게 좌우됩니다.

대응:

  • 서비스 환경에서 스레드 설정을 고정하고 측정
  • 컨테이너 환경에서는 CPU quota와 affinity도 점검

운영 환경에서 프로세스가 계속 재시작되거나 자원이 제한되는 이슈가 섞이면 성능 측정 자체가 왜곡될 수 있습니다. 그런 경우엔 아래 체크리스트처럼 “프로세스/서비스 상태”부터 안정화하는 것이 우선입니다.


배포 관점: 저장, 로딩, 재현성

PT2E/Export 기반은 “그래프가 고정된 형태”라는 장점이 있지만, 저장 포맷과 로딩 전략을 팀 표준으로 잡아두는 게 중요합니다.

일반적으로는 다음을 권합니다.

  • 학습 산출물: FP32 체크포인트 + QAT 설정(quant config) 버전 관리
  • 배포 산출물: 변환된 INT8 모델(ExportedProgram 또는 패키징된 아티팩트)

또한 CI/CD에서 모델 아티팩트를 내려받아 컨테이너로 배포하는 과정에서 권한 문제로 pull이 실패하면(예: ECR 403) 모델 검증 파이프라인이 중단됩니다. 모델 최적화 자체와 별개로 배포 경로도 안정화해두는 게 좋습니다.


실무 체크리스트(요약)

  • torch.export가 안정적으로 되는 모델 구조인가
  • 타깃 백엔드에서 지원되는 양자화 패턴(Conv/Linear 중심)인가
  • prepare_qat_pt2e 이후 QAT 학습을 충분히 수행했는가
  • convert_pt2e 후 실제 INT8 커널로 실행되는가(프로파일링/그래프 확인)
  • 서비스 조건에서 성능을 측정했는가(배치, 스레드, CPU quota)

마무리

PT2E + QAT는 “PyTorch 2 시대의 정석”에 가까운 INT8 양자화 루트입니다. 핵심은 단순히 INT8로 변환하는 것이 아니라, export 가능한 그래프 구성백엔드가 좋아하는 패턴 유지, 그리고 QAT 미세조정으로 분포를 INT8에 맞추는 과정입니다.

다음 단계로는 아래 확장을 고려할 수 있습니다.

  • 더 큰 모델(ResNet, MobileNet)로 확장하며 레이어별 양자화 범위 튜닝
  • ONNX로 내보내 TensorRT INT8 파이프라인으로 연결
  • 프로파일링 기반으로 병목 블록을 찾아 “양자화가 성능을 내는 구조”로 리팩터링

원하면 사용 중인 모델 아키텍처(예: ResNet50, ViT, UNet), 타깃 환경(x86 서버, Graviton, 모바일)과 함께 현재 정확도/지연시간을 알려주면, PT2E QAT 설정을 어떤 방향으로 잡는 게 좋은지(양자화 제외 레이어, per-channel 여부, 학습 스케줄)까지 구체적으로 제안할 수 있습니다.