Published on

PyTorch 2 QAT로 INT8 2배 가속 실전 가이드

Authors

서버 추론에서 병목은 대개 GEMM(행렬곱)과 메모리 대역폭입니다. FP32 모델은 정확도는 좋지만 비용이 큽니다. INT8 양자화는 연산량과 메모리 트래픽을 줄여 CPU 추론을 크게 가속할 수 있고, 잘 맞는 모델에서는 체감상 2x 내외의 속도 개선이 나오는 경우가 많습니다.

다만 “그냥 PTQ(Post Training Quantization)로 끝”이 되는 경우도 있지만, 정확도 손실이 크거나(특히 작은 모델, 민감한 분류 임계값, 검출/세그멘테이션), 활성값 분포가 까다로운 모델에서는 QAT(Quantization Aware Training)가 훨씬 안정적입니다. 이 글은 PyTorch 2 기준으로 FX Graph Mode QAT 파이프라인을 정리하고, 실제로 성능이 나오는 설정과 흔한 실패 지점을 함께 다룹니다.

QAT를 선택해야 하는 기준

PTQ로 충분한 경우

  • 데이터 분포가 안정적이고, 모델이 크며, 약간의 정확도 손실이 허용됨
  • ReLU 기반 CNN, 전형적인 분류 모델
  • 대표 샘플로 calibration(관측)만 잘 해도 성능이 유지됨

QAT가 유리한 경우

  • PTQ에서 정확도 하락이 크거나 편차가 큼
  • 작은 모델(모바일/엣지), 임계값 기반 비즈니스 로직(예: fraud score)
  • 활성값이 민감한 구조(잔차 연결이 많거나, 분포가 넓은 경우)
  • 운영에서 입력 분포가 조금씩 드리프트할 수 있음

QAT는 학습 중에 “가짜 양자화(fake quant)”를 삽입해 INT8의 반올림/클리핑 효과를 모델이 미리 학습하도록 만듭니다. 결과적으로 INT8 변환 후 정확도 보존이 좋아집니다.

PyTorch 2 양자화 스택: 꼭 알아야 할 변화

PyTorch 2 계열에서는 FX Graph Mode(현재는 torch.ao.quantization 중심)가 사실상 표준입니다. 핵심 포인트는 다음입니다.

  • qconfig와 observer 설정이 성능/정확도를 좌우
  • backend는 CPU라면 보통 x86(oneDNN) 또는 fbgemm 계열을 사용
  • 변환 결과는 “양자화된 연산자(quantized ops)”로 치환되어야 실제 속도가 남

또한 PyTorch 2의 torch.compile은 모든 양자화 경로에서 만능은 아닙니다. 양자화된 연산자는 backend에 따라 최적화 경로가 다르므로, “컴파일을 켜면 더 빨라지겠지”라고 단정하면 벤치에서 오히려 손해를 볼 수 있습니다. 결론은 간단합니다. 반드시 벤치마킹으로 확인해야 합니다.

목표: INT8에서 2x를 만들기 위한 체크리스트

2x는 마법이 아니라 조건이 맞아야 합니다.

  • CPU가 INT8 벡터/행렬 최적화(예: AVX2/VNNI 등)를 지원
  • 모델이 Conv/Linear 비중이 높고, 양자화 가능한 블록이 많음
  • 입력 배치/시퀀스 길이가 backend 최적화에 유리함
  • 양자화 후에도 실제로 quantized::conv2d / quantized::linear 같은 커널로 내려감

반대로 아래 조건이면 기대치가 떨어집니다.

  • 작은 텐서 위주(오버헤드가 커짐)
  • 연산이 대부분 LayerNorm/Softmax/Attention 등 비양자화 구간
  • 동적 shape로 인해 최적화가 깨짐

실습: FX Graph Mode QAT 파이프라인

아래 예시는 작은 MLP 형태로 설명하지만, 실제 프로젝트에서는 ConvNet이나 간단한 Transformer 블록에도 동일한 흐름을 적용합니다.

1) 준비: backend와 qconfig 선택

x86 서버 CPU라면 보통 x86 또는 fbgemm를 씁니다(환경에 따라 다름). PyTorch 버전에 따라 추천 backend 문자열이 달라질 수 있으니, 설치된 버전 문서를 확인하세요.

import torch
import torch.nn as nn
import torch.ao.quantization as aoq

# 예: x86(oneDNN) 또는 fbgemm 중 하나를 선택
# 환경에 따라 지원이 다를 수 있습니다.
backend = "x86"  # 또는 "fbgemm"

# QAT 기본 qconfig
qconfig = aoq.get_default_qat_qconfig(backend)

qconfig는 observer(스케일/제로포인트 추정 방식)와 fake quant 모듈을 결정합니다. 속도만 보고 무작정 per-channel을 고르면 정확도는 좋아질 수 있지만 변환/커널 제약이 생길 수도 있습니다. 일반적으로 Linear/Conv weight는 per-channel이 유리한 경우가 많습니다.

2) 모델 구성: fuse 가능한 패턴을 만들기

QAT/INT8에서 속도를 내려면 fuse가 중요합니다. 대표적으로 Conv + BN + ReLU 같은 패턴이 fuse되면 양자화 커널이 효율적으로 동작합니다.

class SmallNet(nn.Module):
    def __init__(self, in_dim=128, hidden=256, out_dim=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)

model_fp32 = SmallNet().train()

ConvNet이면 Conv2dBatchNorm2d를 명시적으로 분리해두는 편이 fuse에 유리합니다.

3) prepare_qat_fx: FX로 fake quant 삽입

FX Graph Mode QAT는 prepare_qat_fx로 그래프를 변환합니다. 이때 예시 입력을 기반으로 트레이싱이 진행됩니다.

from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

example_inputs = (torch.randn(32, 128),)

qconfig_mapping = aoq.QConfigMapping().set_global(qconfig)

prepared = prepare_qat_fx(model_fp32, qconfig_mapping, example_inputs)

여기서 중요한 점:

  • example input shape가 실제 서빙 shape와 크게 다르면, 관측/스케일이 어긋날 수 있습니다.
  • 모델에 data-dependent control flow가 많으면 FX 변환이 어려울 수 있습니다.

4) QAT 파인튜닝: 짧게, 하지만 “충분히”

QAT는 처음부터 길게 학습할 필요가 없는 경우가 많습니다. 보통은 FP32 pretrained를 가져와서 짧게 파인튜닝합니다.

import torch.optim as optim

optimizer = optim.AdamW(prepared.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

prepared.train()
for step in range(200):
    x = torch.randn(32, 128)
    y = torch.randint(0, 10, (32,))

    optimizer.zero_grad(set_to_none=True)
    logits = prepared(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

실전 팁:

  • 초반 몇 step은 observer가 분포를 잡는 구간입니다. 너무 짧으면 스케일이 불안정합니다.
  • 학습 후반에 observer/fake quant를 고정(freeze)하면 수렴이 안정적입니다.
# observer 고정
prepared.apply(aoq.disable_observer)

# fake quant 고정(선택)
prepared.apply(aoq.disable_fake_quant)

프로젝트에 따라 “observer만 고정하고 fake quant는 유지”가 더 좋은 경우도 있습니다.

5) convert_fx: 진짜 INT8 모델로 변환

학습이 끝나면 변환합니다.

prepared.eval()
quantized_model = convert_fx(prepared)

이 시점에 실제로 양자화 커널로 내려갔는지 확인해야 합니다. 가장 빠른 방법은 모델 출력(print)에서 quantized::linear 같은 연산이 보이는지 확인하거나, profiler로 커널을 확인하는 것입니다.

벤치마킹: “정말 2x인가”를 검증하는 방법

속도 측정은 워밍업, 스레드 수, affinity, 입력 shape 고정 여부에 따라 결과가 크게 바뀝니다.

1) 간단한 wall-clock 벤치

import time

def bench(model, x, iters=200, warmup=50):
    model.eval()
    with torch.inference_mode():
        for _ in range(warmup):
            _ = model(x)

        t0 = time.perf_counter()
        for _ in range(iters):
            _ = model(x)
        t1 = time.perf_counter()

    return (t1 - t0) / iters

x = torch.randn(256, 128)

t_fp32 = bench(model_fp32.eval(), x)
# quantized_model은 내부적으로 quant/dequant가 포함될 수 있어 입력 dtype은 보통 fp32로 시작
t_int8 = bench(quantized_model, x)

print("fp32:", t_fp32, "sec/iter")
print("int8:", t_int8, "sec/iter")
print("speedup:", t_fp32 / t_int8)

2) 스레딩 설정

CPU 추론은 스레드 수에 민감합니다.

import os
import torch

torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "8")))
torch.set_num_interop_threads(1)

운영 환경에서는 컨테이너 CPU quota와 함께 설정해야 재현됩니다. 컨테이너 빌드/배포 파이프라인에서 캐시나 설정이 꼬여 성능 측정이 흔들리는 경우도 많습니다. 빌드가 느려져 실험 사이클이 길어지면 최적화 자체가 지연되므로, Docker 캐시 문제도 같이 점검해두면 좋습니다.

정확도 검증: “평균 정확도”만 보면 놓치는 것들

QAT는 보통 평균 metric을 잘 복원하지만, 운영에서는 다음이 더 중요할 수 있습니다.

  • 클래스별 precision/recall 변화
  • 임계값 근처 샘플의 score drift
  • calibration set과 실제 트래픽 분포 차이

권장 루틴:

  1. FP32 vs INT8의 confusion matrix 비교
  2. score histogram과 percentile 비교
  3. 오탐/미탐 상위 N 샘플을 모아 케이스 리뷰

흔한 함정 7가지와 해결법

1) 변환했는데도 빨라지지 않음

원인:

  • 실제로 quantized op로 치환되지 않음
  • 모델 대부분이 비양자화 연산
  • batch가 너무 작아 오버헤드가 큼

대응:

  • 변환 후 그래프/프로파일러로 quantized:: 커널 확인
  • fuse 패턴을 늘리거나(Conv-BN-ReLU), 모델 구조를 양자화 친화적으로 수정

2) INT8인데 정확도가 크게 떨어짐

원인:

  • 대표 입력 분포가 학습/운영과 다름
  • observer 설정이 부적절
  • activation clipping이 심함

대응:

  • QAT 파인튜닝 step을 늘리고, observer freeze 타이밍 조절
  • per-channel weight quant 적용 여부 검토

3) LayerNorm, Softmax가 병목

Transformer 계열에서 INT8로 “전체가” 빨라지지 않는 대표 이유입니다. attention 블록의 핵심 연산이 양자화 커널로 잘 떨어지지 않으면 체감이 제한됩니다.

대응:

  • CPU 추론이면 sequence 길이를 줄이거나, KV 캐시 등 구조 최적화 병행
  • 모델을 Conv/MLP 중심으로 바꾸는 것이 현실적인 경우도 있음

4) torch.compile과의 조합에서 성능이 들쭉날쭉

대응:

  • 컴파일 on/off 둘 다 벤치하고, 더 나은 쪽을 선택
  • 동적 shape를 줄여 재컴파일/가드 비용을 감소

5) 서빙에서 입력 dtype/스케일 전처리가 달라짐

양자화는 입력 분포에 민감합니다. 학습에서는 정규화했는데 서빙에서 누락되면 바로 무너집니다.

대응:

  • 전처리를 모델 안으로 넣거나(가능하면), 전처리 버전을 강제
  • 추론 파이프라인에 스모크 테스트 추가

6) 모델 저장/로드 후 성능 변화

대응:

  • state_dict 저장 후 로드 시, quantized 모델의 로딩 경로를 검증
  • PyTorch 버전/oneDNN 버전 차이로 커널이 달라질 수 있으니 런타임 고정

7) CI에서 재현이 안 됨

양자화는 하드웨어 특성 영향을 크게 받습니다. CI CPU와 프로덕션 CPU가 다르면 결과가 달라집니다.

대응:

  • 성능 벤치는 프로덕션과 유사한 runner에서 수행

  • 배포 자동화 중 권한/토큰 문제로 실험 파이프라인이 끊기면 속도 개선 작업이 지연됩니다. CI 권한 이슈도 미리 정리해두면 좋습니다.

  • GitHub Actions GITHUB_TOKEN 403 권한오류 해결

운영 배포 팁: “빠른 모델”을 “빠르게” 배포하기

  • 모델 아티팩트에 다음 메타데이터를 함께 저장

    • PyTorch 버전, backend(x86/fbgemm), 학습 데이터 버전, 입력 정규화 파라미터
  • INT8 모델은 FP32 대비 민감할 수 있으니, 카나리 배포에서 지표를 더 촘촘히 보세요.

  • Next.js 기반 대시보드/콘솔에서 추론 결과를 렌더링한다면, 서버/클라 렌더 불일치로 디버깅이 어려워질 수 있습니다. 관측 UI가 흔들리면 성능/정확도 회귀를 놓치기 쉽습니다.

  • Next.js Hydration Mismatch 5가지 원인과 해결법

결론: 2x는 가능하지만, “커널로 내려갔는지”가 전부다

PyTorch 2에서 QAT로 INT8을 만들면 정확도 손실을 최소화하면서 CPU 추론 성능을 크게 끌어올릴 수 있습니다. 하지만 속도는 “INT8로 변환했다”가 아니라, 핵심 연산이 실제 INT8 커널로 실행되는지에 달려 있습니다.

실전에서는 다음 순서로 접근하면 시행착오가 줄어듭니다.

  1. PTQ로 빠르게 가능성 확인
  2. 정확도 문제가 있으면 FX QAT로 짧게 파인튜닝
  3. 변환 후 quantized:: 커널 확인
  4. 워밍업/스레드/입력 shape를 고정한 벤치로 2x 검증
  5. 운영 분포에서 score drift까지 포함해 검증

이 루틴을 지키면, “INT8인데 왜 안 빨라요” 같은 질문을 데이터와 프로파일링으로 깔끔하게 끝낼 수 있습니다.