- Published on
PyTorch 2.0+ PTQ로 INT8 변환해 3배 가속하기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 추론 지연 시간을 줄이는 가장 현실적인 방법 중 하나가 INT8 양자화입니다. 특히 GPU가 없거나, CPU 인스턴스 비용이 민감한 환경에서는 FP32 모델을 그대로 돌리기보다 PTQ(Post Training Quantization)로 INT8로 내리는 것만으로도 체감 성능이 크게 좋아집니다.
PyTorch 2.0 이후에는 torch.compile과 TorchInductor 최적화가 들어오면서 “컴파일로 빨라지니 양자화는 필요 없지 않나”라는 질문이 종종 나오는데, CPU 추론에서는 여전히 INT8가 강력합니다. 이유는 단순합니다. 메모리 대역폭과 캐시 효율, 그리고 oneDNN/FBGEMM 기반의 INT8 GEMM이 FP32 대비 훨씬 유리하기 때문입니다.
이 글에서는 PyTorch 2.0+ 기준으로 PTQ로 INT8 변환하고, 실제로 3배 수준 가속을 노리는 흐름을 “실행 가능한 코드” 중심으로 정리합니다.
- 대상: 주로 CPU 추론(
x86)에서 효과가 큼 - 접근: 학습 없이 PTQ로 변환
- 포인트: 캘리브레이션 데이터 품질, 연산자 지원, 측정 방법
참고로, 최적화는 종종 “캐시가 안 먹어서” 성능이 들쭉날쭉해 보일 때가 있습니다. 빌드/배포 파이프라인에서 캐시 관련 이슈를 겪는다면 GitHub Actions 캐시가 안 먹힐 때 원인 9가지도 같이 보면 도움이 됩니다.
PTQ와 QAT 차이, 그리고 PyTorch 2.0+에서의 선택
PTQ: 학습 없이, 대표 입력(캘리브레이션)으로 activation 범위를 추정해INT8로 변환QAT: 학습 과정에 fake-quant를 넣어 양자화 오차를 학습으로 보정
실무에서는 “일단 PTQ로 성능 이득이 충분한지”를 먼저 봅니다. PTQ로 정확도 하락이 크거나 특정 레이어가 민감하면 그때 QAT로 넘어가는 것이 비용 대비 효율적입니다.
PyTorch 2.0+에서 양자화는 크게 두 축으로 볼 수 있습니다.
torch.ao.quantization(Eager/FX Graph Mode 기반)torch.export/torch.compile기반의 최신 흐름(버전에 따라 기능 성숙도 차이)
이 글에서는 가장 범용적이고 재현성 좋은 FX Graph Mode PTQ를 기준으로 설명합니다.
기대 성능: “3배”가 가능한 조건
“INT8로 바꾸면 무조건 3배?”는 아닙니다. 아래 조건이 맞으면 2배에서 4배까지도 꽤 자주 나옵니다.
- 병목이
Linear/GEMM(Transformer의FFN, MLP, 추천 모델 등) 위주 - 배치가 너무 작지 않음(완전
batch=1에서도 이득은 있지만 모델/CPU에 따라 편차) - CPU가
AVX2/VNNI등INT8에 유리한 명령어 지원 - 양자화 지원 연산자 비중이 높음(unsupported op가 많으면 dequant-quant 왕복으로 손해)
반대로 아래면 기대치가 낮아집니다.
- Conv 위주인데 backend 최적화가 덜 맞는 경우
- 전처리/후처리 Python 코드가 전체 latency의 큰 비중을 차지
- 모델에
LayerNorm,Softmax등 양자화가 까다로운 연산이 많고 그래프 분할이 잦음
환경 준비: 버전과 백엔드 확인
CPU INT8 PTQ에서 가장 먼저 확인할 것은 “어떤 quantization backend를 쓰는지”입니다.
x86서버: 보통fbgemmARM계열: 보통qnnpack
아래 코드는 현재 환경에서 설정 가능한 엔진을 확인하고, fbgemm를 선택합니다.
import torch
print(torch.__version__)
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = "fbgemm"
print("engine:", torch.backends.quantized.engine)
만약 supported_engines에 fbgemm가 없다면, CPU/빌드 옵션/플랫폼 문제일 수 있습니다. 이 경우 같은 코드라도 성능이 예상보다 안 나오거나, 특정 변환이 실패할 수 있습니다.
실전: FX Graph Mode PTQ로 INT8 변환
아래는 “FP32 모델”을 PTQ로 INT8 변환하는 전형적인 흐름입니다.
핵심 단계는 4개입니다.
- 모델을
eval()로 전환 qconfig설정prepare_fx로 관측자(observer) 삽입- 캘리브레이션 데이터를 몇 배치 흘려보낸 뒤
convert_fx
예제 모델
실제 서비스 모델이 Transformer든 CNN이든 과정은 같습니다. 여기서는 간단히 MLP를 예시로 듭니다.
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, in_dim=1024, hidden=4096, out_dim=1000):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, out_dim),
)
def forward(self, x):
return self.net(x)
fp32_model = MLP().eval()
FX PTQ 코드
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
torch.backends.quantized.engine = "fbgemm"
# 예시 입력 스펙(실제 서비스 입력 shape에 맞추는 것이 중요)
example_inputs = (torch.randn(32, 1024),)
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
prepared = prepare_fx(fp32_model, qconfig_dict, example_inputs)
# 캘리브레이션: 실제 데이터 분포를 대표하는 입력을 흘려보내기
with torch.inference_mode():
for _ in range(50):
x = torch.randn(32, 1024)
prepared(x)
int8_model = convert_fx(prepared)
int8_model.eval()
여기서 캘리브레이션 루프의 torch.randn은 “동작 예시”일 뿐입니다. 정확도 하락을 줄이려면 반드시 실제 트래픽/검증 데이터에서 샘플링한 입력을 사용하세요.
성능 측정: 워밍업과 타이머, 스레드 고정
“3배 가속”을 말하려면 측정이 정확해야 합니다. CPU 추론은 워밍업, 스레드 수, 터보 부스트, NUMA, 메모리 배치에 크게 흔들립니다.
아래는 최소한의 공정한 측정 템플릿입니다.
torch.inference_mode()사용- 워밍업 반복
- 충분한 반복으로 평균 계산
torch.set_num_threads로 스레드 수 고정
import time
import torch
def benchmark(model, x, warmup=50, iters=200):
model.eval()
with torch.inference_mode():
for _ in range(warmup):
_ = model(x)
t0 = time.perf_counter()
for _ in range(iters):
_ = model(x)
t1 = time.perf_counter()
return (t1 - t0) / iters
# 스레드 수는 환경에 맞게 조정(예: 물리 코어 수)
torch.set_num_threads(8)
test_x = torch.randn(32, 1024)
fp32_t = benchmark(fp32_model, test_x)
int8_t = benchmark(int8_model, test_x)
print("fp32 sec/iter:", fp32_t)
print("int8 sec/iter:", int8_t)
print("speedup:", fp32_t / int8_t)
측정 결과가 들쭉날쭉하면 아래를 점검하세요.
- 워밍업 부족(특히
torch.compile을 섞으면 더 필요) - 스레드 수가 OS 스케줄링에 의해 흔들림
- 컨테이너 환경에서 CPU quota 제한
- 동일 머신에서 다른 워크로드 간섭
정확도 하락을 줄이는 캘리브레이션 전략
PTQ의 성패는 캘리브레이션이 좌우합니다. “대충 100배치 흘리면 되겠지”는 모델에 따라 크게 실패합니다.
1) 대표성 있는 데이터
- 실서비스 입력 분포를 반영
- 이상치(outlier)가 섞여 있으면 activation range가 넓어져 정밀도가 떨어질 수 있음
실무 팁:
- 최근 1일 트래픽에서 랜덤 샘플링
- 카테고리/언어/길이 등 조건별 stratified sampling
2) 캘리브레이션 배치 수
- 너무 적으면 range 추정이 불안정
- 너무 많으면 시간만 낭비(PTQ 자체가 느려짐)
경험적으로는 50에서 500 배치 사이에서 먼저 스윕해 보고, 정확도와 속도를 함께 보면서 결정하는 편이 안전합니다.
3) Per-channel weight quantization
기본 qconfig는 보통 weight를 per-channel로 잡아주는 경우가 많아 정확도에 유리합니다. 커스텀 설정을 하다가 per-tensor로 내려가면 정확도가 크게 흔들릴 수 있습니다.
자주 만나는 문제: INT8로 바꿨는데 느려진다
INT8로 변환했는데 오히려 느려지는 경우는 대체로 아래 중 하나입니다.
1) 지원되지 않는 연산자가 많아 그래프가 쪼개짐
양자화된 구간과 FP32 구간이 번갈아 나오면 quant/dequant 오버헤드가 생깁니다. 특히 작은 텐서에서 이 비용이 상대적으로 커집니다.
해결 방향:
- 모델 구조를 단순화(가능하면
Linear-ReLU-Linear같은 패턴 유지) - 양자화 친화적인 블록으로 교체
- 필요 시 특정 서브모듈만 선택적으로 양자화
2) 배치가 너무 작고 Python 오버헤드가 큼
- 전처리/후처리 시간이 전체의 큰 비중이면 INT8 이득이 가려집니다.
- 이때는 전처리 최적화, 배치 전략, 또는 엔드투엔드 프로파일링이 먼저입니다.
웹/서버 프레임워크에서 캐시가 꼬여 성능 측정이 일관되지 않을 때도 있는데, 프론트 렌더링에서 비슷한 문제를 겪는다면 Next.js 14 RSC 캐시 꼬임으로 갱신이 안될 때처럼 “캐시가 원인인지”를 분리해 보는 사고방식이 도움이 됩니다.
3) 스레드/런타임 설정 문제
OMP_NUM_THREADS,MKL_NUM_THREADS등 환경 변수- 컨테이너 CPU limit
- NUMA 바인딩
특히 쿠버네티스에서 CPU request/limit이 애매하면 벤치마크가 크게 흔들립니다.
torch.compile과 INT8 PTQ를 같이 쓰면 더 빨라질까
PyTorch 2.x에서 torch.compile은 FP32 모델에도 큰 이득을 줄 수 있지만, 양자화 모델과의 궁합은 버전/모델/백엔드에 따라 다릅니다.
실무적으로는 아래 순서가 안전합니다.
- FP32에서
torch.compile로 baseline 가속 - INT8 PTQ 적용 후 성능 측정
- INT8 모델에
torch.compile을 적용해 추가 이득이 있는지 실험
단, 양자화 모델은 내부적으로 quantized op를 사용하며, 컴파일러가 이를 어떻게 다루는지에 따라 이득이 없거나 오히려 손해일 수 있습니다. 따라서 “항상 더 빨라진다”는 접근은 위험하고, 반드시 측정으로 결정해야 합니다.
배포 관점 체크리스트
마지막으로, “로컬에서는 3배인데 서버에서는 1.3배” 같은 상황을 줄이려면 아래를 체크하세요.
- 동일한 CPU 아키텍처인가(로컬은 최신, 서버는 구형이면 차이 큼)
fbgemm사용 여부- 스레드 수와 CPU quota 고정 여부
- 입력 shape가 실제 트래픽과 동일한가
- 캘리브레이션 데이터가 실제 분포를 반영하는가
그리고 모델 파일 저장/로딩도 잊지 마세요. FX 변환 모델은 보통 state_dict 저장이 가능하지만, 환경 차이를 줄이려면 torch.jit 또는 export 계열로 아티팩트를 고정하는 전략도 검토할 만합니다(프로젝트 요구사항에 따라 선택).
정리
PyTorch 2.0+에서 PTQ로 INT8 변환은 “학습 없이” CPU 추론을 크게 가속할 수 있는 가장 비용 효율적인 방법입니다. 다만 3배 가속을 안정적으로 얻으려면 다음이 핵심입니다.
fbgemm등 올바른 backend 선택- 대표성 있는 캘리브레이션 데이터
- 워밍업/스레드 고정 등 공정한 벤치마크
- 양자화 가능한 연산자 비중 확보(그래프 분할 최소화)
이 4가지만 제대로 잡아도, 많은 MLP/추천/일부 Transformer 계열 워크로드에서 “현실적인 2배에서 4배” 구간에 진입할 가능성이 높습니다.