- Published on
PyTorch 2.3 PT2E+XNNPACK로 INT8 양자화
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버에서 CPU 추론 비용을 줄이려면 결국 INT8로 내려가야 하는 순간이 옵니다. 특히 모바일·엣지뿐 아니라, 클라우드에서도 CPU 스케일아웃 비용을 줄이거나 GPU가 없는 환경에서 지연 시간을 맞추려면 INT8은 여전히 강력한 카드입니다.
PyTorch 2.x에서는 기존 FX 기반 양자화 흐름과 더불어, torch.export를 중심으로 한 PT2E(Post Training 2 Export) 스타일의 파이프라인이 자리 잡았습니다. 여기에 XNNPACK 백엔드를 붙이면 ARM/모바일 계열에서 특히 좋은 성능을 기대할 수 있고, x86에서도 상황에 따라 이점을 얻을 수 있습니다.
이 글은 PyTorch 2.3 기준으로 PT2E+XNNPACK INT8 양자화를 “실제로 돌아가게 만드는” 데 초점을 맞춥니다. PyTorch 2.0 시절 실전 흐름을 먼저 보고 싶다면 이전 글인 PyTorch 2.0 PT2E+XNNPACK int8 양자화 실전도 같이 참고하면 비교가 쉽습니다.
왜 PyTorch 2.3에서 PT2E인가
PyTorch 2.3에서 PT2E 접근을 쓰는 이유는 크게 3가지입니다.
- 그래프 안정성:
torch.export가 만들어내는 ExportedProgram은 추론 그래프를 더 엄격하게 고정합니다. 동적 제어 흐름이나 파이썬 레벨 의존성이 줄어들어, 양자화 변환이 예측 가능해집니다. - 백엔드 지향: 양자화는 결국 특정 커널(XNNPACK, QNNPACK, FBGEMM 등)에 맞춰 연산을 재작성하는 작업입니다. Export 기반 흐름은 “백엔드가 원하고, 런타임이 실행 가능한” 형태로 정리하기 좋습니다.
- AOT/컴파일 파이프라인과의 결합:
torch.compile및 AOT 계열과 섞는 시나리오에서, Export는 중간 표현으로서 쓸모가 큽니다.
다만, PT2E는 “편한 마법”이라기보다 파이프라인을 정확히 이해해야 원하는 결과가 나오는 도구에 가깝습니다. 아래에서 준비부터 디버깅까지 단계별로 정리합니다.
준비물: 버전, 백엔드, 주의사항
권장 환경
- Python 3.10 이상 권장
- PyTorch 2.3.x
- torchvision은 모델에 따라 필요
- 모바일/ARM에서 XNNPACK 이점이 큼
설치 예시는 다음처럼 진행할 수 있습니다.
pip install torch==2.3.* torchvision --index-url https://download.pytorch.org/whl/cpu
GPU 환경에서도 양자화 자체는 가능하지만, 여기서는 CPU 추론용 INT8을 목표로 합니다.
XNNPACK과 엔진 선택
PyTorch의 양자화 백엔드 선택은 보통 torch.backends.quantized.engine로 확인합니다. 환경에 따라 기본값이 다를 수 있어, 의도한 엔진으로 고정하는 습관이 좋습니다.
import torch
print(torch.backends.quantized.supported_engines)
print(torch.backends.quantized.engine)
# 필요 시 고정
# torch.backends.quantized.engine = "xnnpack"
주의할 점은, “엔진을 xnnpack으로 설정했다”가 “PT2E 변환 결과가 XNNPACK 커널로 최적 실행된다”를 100% 보장하진 않는다는 겁니다. 연산 패턴(예: Conv+ReLU fusion 가능 여부), 지원 dtype/shape, 관측기 설정 등에 따라 결과가 달라질 수 있습니다.
전체 파이프라인 개요
PT2E INT8 양자화의 큰 흐름은 다음 4단계로 이해하면 편합니다.
- Export:
torch.export.export로 추론 그래프를 고정 - Prepare: 관측기(Observer) 삽입, 캘리브레이션 준비
- Calibrate: 대표 데이터로 몇 배치 흘려서 activation 통계 수집
- Convert: INT8로 변환 후 실행/검증
여기서 가장 중요한 건 대표 데이터(캘리브레이션 데이터)의 품질입니다. 분포가 어긋나면 INT8에서 정확도가 급격히 떨어질 수 있습니다.
예제 모델과 데이터: 최소 재현 코드
실전에서 가장 흔한 문제는 “내 모델은 커스텀 모듈이 많고, export가 안 된다”입니다. 그래서 먼저 export가 잘 되는 형태로 최소 예제를 만들고, 그 다음 실제 모델로 확장하는 접근을 추천합니다.
아래는 Conv 기반의 간단한 이미지 모델 예시입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SmallConv(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.mean(dim=(2, 3))
x = self.fc(x)
return x
model = SmallConv().eval()
example = torch.randn(1, 3, 224, 224)
1) Export: torch.export.export로 그래프 고정
torch.export는 입력의 shape, dtype, 제약 조건을 기반으로 그래프를 고정합니다. 가장 단순한 시작은 example input으로 export하는 것입니다.
import torch
ep = torch.export.export(model, (example,))
print(type(ep))
여기서 실패하는 대표 케이스는 다음과 같습니다.
- forward 내부에 파이썬 if/for가 있고 입력 값에 따라 제어 흐름이 바뀜
- 텐서 shape를 파이썬 정수로 뽑아 리스트 인덱싱 등에 사용
- 일부 연산이 export/ATen 그래프에서 지원되지 않음
이런 경우엔 모델을 “export-friendly”하게 바꾸거나, 동적 shape 제약을 명시해야 합니다. 먼저는 정적 shape로 확정하고 성공시키는 것이 우선입니다.
2) Prepare: 관측기 삽입과 QConfig 선택
PTQ(Post Training Quantization)에서는 activation 통계를 모아 scale/zero-point를 정합니다. PyTorch는 이를 위해 Observer를 삽입합니다.
PyTorch 2.3의 세부 API는 릴리즈마다 조금씩 달라질 수 있지만, 핵심은 다음입니다.
- 백엔드에 맞는
QConfig를 선택 - ExportedProgram에 관측기를 삽입
아래 코드는 “개념을 유지하면서” 현재 PyTorch 양자화 API 흐름에 맞춘 형태의 예시입니다. 실제 프로젝트에서는 사용 중인 torch 버전에 맞춰 torch.ao.quantization 및 torch.ao.quantization.quantize_pt2e 계열 함수명을 확인하세요.
import torch
from torch.ao.quantization import get_default_qconfig
# XNNPACK용 qconfig 선택
qconfig = get_default_qconfig("xnnpack")
# 관측기 삽입(prepare 단계)
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
prepared = prepare_pt2e(ep, qconfig)
만약 여기서 에러가 난다면 대개 다음 원인입니다.
- 특정 op가 정량화 패턴으로 매칭되지 않음
- 준비 단계가 예상한 그래프 형태(예: Conv+ReLU)로 나오지 않음
- 모델에 이미
QuantStub같은 구식 스텁이 섞여 충돌
PT2E에서는 가급적 모델을 float로 깨끗하게 유지하고, prepare/convert에서만 변환을 적용하는 편이 문제를 줄입니다.
3) Calibrate: 대표 데이터로 통계 수집
캘리브레이션은 “학습”이 아니라 “통계 수집”입니다. 그렇지만 결과 정확도를 좌우하므로 절대 대충 하면 안 됩니다.
- 데이터는 실제 추론 입력 분포를 대표해야 함
- 최소 수십~수백 배치 권장(모델/도메인에 따라 다름)
- augmentation은 보통 끄는 편이 안전(실추론 분포를 따르기)
예시는 랜덤 텐서로 돌리지만, 실제로는 DataLoader를 사용하세요.
prepared.eval()
with torch.inference_mode():
for _ in range(200):
x = torch.randn(1, 3, 224, 224)
_ = prepared.module()(x) if hasattr(prepared, "module") else prepared(x)
여기서 prepared(x) 호출 방식은 객체 타입에 따라 달라질 수 있습니다. ExportedProgram 래핑 형태에 따라 prepared가 callable이거나, 내부에 실제 callable이 들어있을 수 있습니다. 프로젝트에서는 print(type(prepared))로 확인하고 호출 방식을 고정하세요.
4) Convert: INT8 모델로 변환
통계를 모았다면 convert를 수행합니다.
from torch.ao.quantization.quantize_pt2e import convert_pt2e
quantized = convert_pt2e(prepared)
변환 후에는 다음을 반드시 확인합니다.
- 출력 shape 동일
- 수치 오차(정확도) 허용 범위 내
- 지연 시간 및 CPU 사용률 개선
간단한 수치 비교 예시는 다음과 같습니다.
model.eval()
with torch.inference_mode():
x = torch.randn(1, 3, 224, 224)
y_fp32 = model(x)
y_int8 = quantized(x)
diff = (y_fp32 - y_int8).abs().mean().item()
print("mean abs diff:", diff)
분류 모델이라면 Top-1 정확도를 비교하거나, 회귀라면 MAE/MSE를 비교하는 식으로 “업무 지표” 기준으로 검증해야 합니다.
성능 측정: 기대한 만큼 빨라졌는지 확인
INT8로 바꿨는데도 빨라지지 않는 경우가 꽤 흔합니다. 원인은 크게 4가지입니다.
- 지원되지 않는 op가 많아 float fallback이 많음
- 배치/shape가 XNNPACK에 불리
- 스레드 설정 미스
- 측정 방법이 잘못됨(워밍업 없음, 동기화 이슈 등)
간단한 벤치 코드는 다음처럼 작성할 수 있습니다.
import time
import torch
def bench(fn, x, iters=200, warmup=50):
fn.eval()
with torch.inference_mode():
for _ in range(warmup):
_ = fn(x)
t0 = time.perf_counter()
for _ in range(iters):
_ = fn(x)
t1 = time.perf_counter()
return (t1 - t0) / iters
x = torch.randn(1, 3, 224, 224)
fp32_t = bench(model, x)
int8_t = bench(quantized, x)
print("fp32 sec/iter:", fp32_t)
print("int8 sec/iter:", int8_t)
또한 CPU 스레드 수는 결과에 큰 영향을 줍니다.
import torch
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
실서비스 환경(컨테이너 CPU quota, NUMA, 스레드 제한)과 동일하게 맞춘 뒤 측정해야 합니다.
정확도 하락을 줄이는 실전 팁
1) 캘리브레이션 데이터 분포를 맞춰라
INT8에서 activation clipping이 발생하면 정보 손실이 커집니다. 특히 입력 스케일이 큰 도메인(예: HDR 이미지 전처리, 오디오 스펙트럼 등)에서는 캘리브레이션 데이터가 조금만 어긋나도 정확도가 크게 떨어질 수 있습니다.
- 전처리(정규화, 리사이즈, 컬러 스페이스)를 실추론과 동일하게
- outlier가 많은 데이터라면 percentile 기반 observer를 고려
2) 레이어별 민감도 분석
정확도가 많이 떨어지면 “어떤 블록이 문제인지”부터 찾아야 합니다.
- 특정 블록만 float로 유지(부분 양자화)
- 마지막 FC나 LayerNorm 계열은 float 유지가 유리한 경우가 있음
PT2E에서도 결국은 “어떤 op가 quantize/dequantize 경계를 갖는지”가 중요합니다.
3) 연산 패턴을 단순화하라
양자화는 패턴 매칭이 잘 되는 구조가 유리합니다.
- Conv 다음에 즉시 ReLU가 오도록(가능한 경우)
- 불필요한 view/permute 남발을 줄이기
- activation 함수가 특이하면 대체 가능성 검토
흔한 에러와 디버깅 체크리스트
Export가 안 될 때
- forward에서 텐서 값을 파이썬 분기로 쓰는지 확인
- 입력 shape를 고정하고 먼저 성공시키기
- 커스텀 op가 있으면 ATen으로 내릴 수 있는지 검토
Convert 후 속도가 안 나올 때
- quantized 그래프에 float fallback이 많은지 확인
- 모델의 주요 연산(Conv/GEMM)이 실제로 INT8 커널로 내려갔는지 확인
- 스레드/배치/입력 크기 튜닝
정확도가 크게 떨어질 때
- 캘리브레이션 배치 수 증가
- 캘리브레이션 데이터 분포 재점검
- 부분 양자화로 민감 레이어를 float 유지
운영 환경에서 문제가 생기면, 원인 진단을 체계적으로 하는 게 시간을 줄입니다. 장애/성능 이슈를 빠르게 좁혀가는 접근은 인프라 쪽에서도 동일합니다. 예를 들어 서비스가 비정상 재시작을 반복할 때 원인을 단계적으로 좁히는 방식은 systemd 서비스 자동 재시작 무한루프 진단 가이드처럼 “가설-검증” 루틴이 핵심입니다.
PyTorch 2.3에서의 운영 적용 포인트
모델 버전 관리와 재현성
양자화는 “데이터 통계”에 의존하므로, 다음을 같이 버전 관리하는 편이 안전합니다.
- torch 버전, quantization 관련 설정
- 캘리브레이션 데이터 샘플링 기준
- 입력 전처리 코드
프로덕션에서 재현성이 흔들리면, 같은 모델인데도 배포마다 정확도/지연이 요동칠 수 있습니다.
실패를 정상 흐름으로 다루기
실서비스에서는 “양자화 변환이 항상 성공한다”를 가정하면 위험합니다. 특정 모델/버전에서 export나 convert가 깨질 수 있고, 이때 FP32로 폴백하는 전략이 필요합니다. 이런 식의 실패를 값으로 다루는 패턴은 예외 기반보다 운영에 유리한 면이 있습니다. 관심 있다면 C++23 std expected로 예외 없이 안전한 자원관리처럼 실패를 명시적으로 모델링하는 접근을 참고해도 좋습니다.
마무리: PT2E+XNNPACK INT8의 현실적인 기대치
PyTorch 2.3의 PT2E 기반 양자화는 “최신 파이프라인”이지만, 여전히 모델 구조와 연산 패턴, 캘리브레이션 품질에 성패가 갈립니다. 성공하면 CPU 추론에서 지연 시간과 비용을 동시에 줄일 수 있고, 특히 모바일/ARM 계열에서는 XNNPACK 조합이 실전 효용이 큽니다.
추천하는 진행 순서는 다음입니다.
- 가장 작은 서브모델로 export/prepare/convert 성공
- 캘리브레이션 데이터 품질을 먼저 확보
- 정확도 하락이 크면 부분 양자화로 타협점 찾기
- 마지막에 성능 측정과 스레드/배치 튜닝
다음 단계로는, 실제 서비스 입력 분포를 반영한 캘리브레이션 파이프라인 자동화와, 모델 변경 시 양자화 회귀 테스트(정확도/지연)를 CI에 넣는 것을 권장합니다.