- Published on
PyTorch 2 QAT로 INT8 2배 가속 실전 가이드
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버 추론에서 병목은 대개 GEMM(행렬곱)과 메모리 대역폭입니다. FP32 모델은 정확도는 좋지만 비용이 큽니다. INT8 양자화는 연산량과 메모리 트래픽을 줄여 CPU 추론을 크게 가속할 수 있고, 잘 맞는 모델에서는 체감상 2x 내외의 속도 개선이 나오는 경우가 많습니다.
다만 “그냥 PTQ(Post Training Quantization)로 끝”이 되는 경우도 있지만, 정확도 손실이 크거나(특히 작은 모델, 민감한 분류 임계값, 검출/세그멘테이션), 활성값 분포가 까다로운 모델에서는 QAT(Quantization Aware Training)가 훨씬 안정적입니다. 이 글은 PyTorch 2 기준으로 FX Graph Mode QAT 파이프라인을 정리하고, 실제로 성능이 나오는 설정과 흔한 실패 지점을 함께 다룹니다.
QAT를 선택해야 하는 기준
PTQ로 충분한 경우
- 데이터 분포가 안정적이고, 모델이 크며, 약간의 정확도 손실이 허용됨
- ReLU 기반 CNN, 전형적인 분류 모델
- 대표 샘플로 calibration(관측)만 잘 해도 성능이 유지됨
QAT가 유리한 경우
- PTQ에서 정확도 하락이 크거나 편차가 큼
- 작은 모델(모바일/엣지), 임계값 기반 비즈니스 로직(예: fraud score)
- 활성값이 민감한 구조(잔차 연결이 많거나, 분포가 넓은 경우)
- 운영에서 입력 분포가 조금씩 드리프트할 수 있음
QAT는 학습 중에 “가짜 양자화(fake quant)”를 삽입해 INT8의 반올림/클리핑 효과를 모델이 미리 학습하도록 만듭니다. 결과적으로 INT8 변환 후 정확도 보존이 좋아집니다.
PyTorch 2 양자화 스택: 꼭 알아야 할 변화
PyTorch 2 계열에서는 FX Graph Mode(현재는 torch.ao.quantization 중심)가 사실상 표준입니다. 핵심 포인트는 다음입니다.
qconfig와 observer 설정이 성능/정확도를 좌우- backend는 CPU라면 보통
x86(oneDNN) 또는fbgemm계열을 사용 - 변환 결과는 “양자화된 연산자(quantized ops)”로 치환되어야 실제 속도가 남
또한 PyTorch 2의 torch.compile은 모든 양자화 경로에서 만능은 아닙니다. 양자화된 연산자는 backend에 따라 최적화 경로가 다르므로, “컴파일을 켜면 더 빨라지겠지”라고 단정하면 벤치에서 오히려 손해를 볼 수 있습니다. 결론은 간단합니다. 반드시 벤치마킹으로 확인해야 합니다.
목표: INT8에서 2x를 만들기 위한 체크리스트
2x는 마법이 아니라 조건이 맞아야 합니다.
- CPU가 INT8 벡터/행렬 최적화(예: AVX2/VNNI 등)를 지원
- 모델이 Conv/Linear 비중이 높고, 양자화 가능한 블록이 많음
- 입력 배치/시퀀스 길이가 backend 최적화에 유리함
- 양자화 후에도 실제로
quantized::conv2d/quantized::linear같은 커널로 내려감
반대로 아래 조건이면 기대치가 떨어집니다.
- 작은 텐서 위주(오버헤드가 커짐)
- 연산이 대부분 LayerNorm/Softmax/Attention 등 비양자화 구간
- 동적 shape로 인해 최적화가 깨짐
실습: FX Graph Mode QAT 파이프라인
아래 예시는 작은 MLP 형태로 설명하지만, 실제 프로젝트에서는 ConvNet이나 간단한 Transformer 블록에도 동일한 흐름을 적용합니다.
1) 준비: backend와 qconfig 선택
x86 서버 CPU라면 보통 x86 또는 fbgemm를 씁니다(환경에 따라 다름). PyTorch 버전에 따라 추천 backend 문자열이 달라질 수 있으니, 설치된 버전 문서를 확인하세요.
import torch
import torch.nn as nn
import torch.ao.quantization as aoq
# 예: x86(oneDNN) 또는 fbgemm 중 하나를 선택
# 환경에 따라 지원이 다를 수 있습니다.
backend = "x86" # 또는 "fbgemm"
# QAT 기본 qconfig
qconfig = aoq.get_default_qat_qconfig(backend)
qconfig는 observer(스케일/제로포인트 추정 방식)와 fake quant 모듈을 결정합니다. 속도만 보고 무작정 per-channel을 고르면 정확도는 좋아질 수 있지만 변환/커널 제약이 생길 수도 있습니다. 일반적으로 Linear/Conv weight는 per-channel이 유리한 경우가 많습니다.
2) 모델 구성: fuse 가능한 패턴을 만들기
QAT/INT8에서 속도를 내려면 fuse가 중요합니다. 대표적으로 Conv + BN + ReLU 같은 패턴이 fuse되면 양자화 커널이 효율적으로 동작합니다.
class SmallNet(nn.Module):
def __init__(self, in_dim=128, hidden=256, out_dim=10):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, out_dim),
)
def forward(self, x):
return self.net(x)
model_fp32 = SmallNet().train()
ConvNet이면 Conv2d와 BatchNorm2d를 명시적으로 분리해두는 편이 fuse에 유리합니다.
3) prepare_qat_fx: FX로 fake quant 삽입
FX Graph Mode QAT는 prepare_qat_fx로 그래프를 변환합니다. 이때 예시 입력을 기반으로 트레이싱이 진행됩니다.
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
example_inputs = (torch.randn(32, 128),)
qconfig_mapping = aoq.QConfigMapping().set_global(qconfig)
prepared = prepare_qat_fx(model_fp32, qconfig_mapping, example_inputs)
여기서 중요한 점:
- example input shape가 실제 서빙 shape와 크게 다르면, 관측/스케일이 어긋날 수 있습니다.
- 모델에 data-dependent control flow가 많으면 FX 변환이 어려울 수 있습니다.
4) QAT 파인튜닝: 짧게, 하지만 “충분히”
QAT는 처음부터 길게 학습할 필요가 없는 경우가 많습니다. 보통은 FP32 pretrained를 가져와서 짧게 파인튜닝합니다.
import torch.optim as optim
optimizer = optim.AdamW(prepared.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
prepared.train()
for step in range(200):
x = torch.randn(32, 128)
y = torch.randint(0, 10, (32,))
optimizer.zero_grad(set_to_none=True)
logits = prepared(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
실전 팁:
- 초반 몇 step은 observer가 분포를 잡는 구간입니다. 너무 짧으면 스케일이 불안정합니다.
- 학습 후반에 observer/fake quant를 고정(freeze)하면 수렴이 안정적입니다.
# observer 고정
prepared.apply(aoq.disable_observer)
# fake quant 고정(선택)
prepared.apply(aoq.disable_fake_quant)
프로젝트에 따라 “observer만 고정하고 fake quant는 유지”가 더 좋은 경우도 있습니다.
5) convert_fx: 진짜 INT8 모델로 변환
학습이 끝나면 변환합니다.
prepared.eval()
quantized_model = convert_fx(prepared)
이 시점에 실제로 양자화 커널로 내려갔는지 확인해야 합니다. 가장 빠른 방법은 모델 출력(print)에서 quantized::linear 같은 연산이 보이는지 확인하거나, profiler로 커널을 확인하는 것입니다.
벤치마킹: “정말 2x인가”를 검증하는 방법
속도 측정은 워밍업, 스레드 수, affinity, 입력 shape 고정 여부에 따라 결과가 크게 바뀝니다.
1) 간단한 wall-clock 벤치
import time
def bench(model, x, iters=200, warmup=50):
model.eval()
with torch.inference_mode():
for _ in range(warmup):
_ = model(x)
t0 = time.perf_counter()
for _ in range(iters):
_ = model(x)
t1 = time.perf_counter()
return (t1 - t0) / iters
x = torch.randn(256, 128)
t_fp32 = bench(model_fp32.eval(), x)
# quantized_model은 내부적으로 quant/dequant가 포함될 수 있어 입력 dtype은 보통 fp32로 시작
t_int8 = bench(quantized_model, x)
print("fp32:", t_fp32, "sec/iter")
print("int8:", t_int8, "sec/iter")
print("speedup:", t_fp32 / t_int8)
2) 스레딩 설정
CPU 추론은 스레드 수에 민감합니다.
import os
import torch
torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "8")))
torch.set_num_interop_threads(1)
운영 환경에서는 컨테이너 CPU quota와 함께 설정해야 재현됩니다. 컨테이너 빌드/배포 파이프라인에서 캐시나 설정이 꼬여 성능 측정이 흔들리는 경우도 많습니다. 빌드가 느려져 실험 사이클이 길어지면 최적화 자체가 지연되므로, Docker 캐시 문제도 같이 점검해두면 좋습니다.
정확도 검증: “평균 정확도”만 보면 놓치는 것들
QAT는 보통 평균 metric을 잘 복원하지만, 운영에서는 다음이 더 중요할 수 있습니다.
- 클래스별 precision/recall 변화
- 임계값 근처 샘플의 score drift
- calibration set과 실제 트래픽 분포 차이
권장 루틴:
- FP32 vs INT8의 confusion matrix 비교
- score histogram과 percentile 비교
- 오탐/미탐 상위 N 샘플을 모아 케이스 리뷰
흔한 함정 7가지와 해결법
1) 변환했는데도 빨라지지 않음
원인:
- 실제로 quantized op로 치환되지 않음
- 모델 대부분이 비양자화 연산
- batch가 너무 작아 오버헤드가 큼
대응:
- 변환 후 그래프/프로파일러로
quantized::커널 확인 - fuse 패턴을 늘리거나(Conv-BN-ReLU), 모델 구조를 양자화 친화적으로 수정
2) INT8인데 정확도가 크게 떨어짐
원인:
- 대표 입력 분포가 학습/운영과 다름
- observer 설정이 부적절
- activation clipping이 심함
대응:
- QAT 파인튜닝 step을 늘리고, observer freeze 타이밍 조절
- per-channel weight quant 적용 여부 검토
3) LayerNorm, Softmax가 병목
Transformer 계열에서 INT8로 “전체가” 빨라지지 않는 대표 이유입니다. attention 블록의 핵심 연산이 양자화 커널로 잘 떨어지지 않으면 체감이 제한됩니다.
대응:
- CPU 추론이면 sequence 길이를 줄이거나, KV 캐시 등 구조 최적화 병행
- 모델을 Conv/MLP 중심으로 바꾸는 것이 현실적인 경우도 있음
4) torch.compile과의 조합에서 성능이 들쭉날쭉
대응:
- 컴파일 on/off 둘 다 벤치하고, 더 나은 쪽을 선택
- 동적 shape를 줄여 재컴파일/가드 비용을 감소
5) 서빙에서 입력 dtype/스케일 전처리가 달라짐
양자화는 입력 분포에 민감합니다. 학습에서는 정규화했는데 서빙에서 누락되면 바로 무너집니다.
대응:
- 전처리를 모델 안으로 넣거나(가능하면), 전처리 버전을 강제
- 추론 파이프라인에 스모크 테스트 추가
6) 모델 저장/로드 후 성능 변화
대응:
state_dict저장 후 로드 시, quantized 모델의 로딩 경로를 검증- PyTorch 버전/oneDNN 버전 차이로 커널이 달라질 수 있으니 런타임 고정
7) CI에서 재현이 안 됨
양자화는 하드웨어 특성 영향을 크게 받습니다. CI CPU와 프로덕션 CPU가 다르면 결과가 달라집니다.
대응:
성능 벤치는 프로덕션과 유사한 runner에서 수행
배포 자동화 중 권한/토큰 문제로 실험 파이프라인이 끊기면 속도 개선 작업이 지연됩니다. CI 권한 이슈도 미리 정리해두면 좋습니다.
운영 배포 팁: “빠른 모델”을 “빠르게” 배포하기
모델 아티팩트에 다음 메타데이터를 함께 저장
- PyTorch 버전, backend(
x86/fbgemm), 학습 데이터 버전, 입력 정규화 파라미터
- PyTorch 버전, backend(
INT8 모델은 FP32 대비 민감할 수 있으니, 카나리 배포에서 지표를 더 촘촘히 보세요.
Next.js 기반 대시보드/콘솔에서 추론 결과를 렌더링한다면, 서버/클라 렌더 불일치로 디버깅이 어려워질 수 있습니다. 관측 UI가 흔들리면 성능/정확도 회귀를 놓치기 쉽습니다.
결론: 2x는 가능하지만, “커널로 내려갔는지”가 전부다
PyTorch 2에서 QAT로 INT8을 만들면 정확도 손실을 최소화하면서 CPU 추론 성능을 크게 끌어올릴 수 있습니다. 하지만 속도는 “INT8로 변환했다”가 아니라, 핵심 연산이 실제 INT8 커널로 실행되는지에 달려 있습니다.
실전에서는 다음 순서로 접근하면 시행착오가 줄어듭니다.
- PTQ로 빠르게 가능성 확인
- 정확도 문제가 있으면 FX QAT로 짧게 파인튜닝
- 변환 후
quantized::커널 확인 - 워밍업/스레드/입력 shape를 고정한 벤치로
2x검증 - 운영 분포에서 score drift까지 포함해 검증
이 루틴을 지키면, “INT8인데 왜 안 빨라요” 같은 질문을 데이터와 프로파일링으로 깔끔하게 끝낼 수 있습니다.