Published on

PyTorch QAT로 INT8 양자화 시 정확도 하락 대처

Authors

서빙 비용과 지연 시간을 줄이려고 INT8 양자화를 붙였는데, 막상 QAT를 켜면 정확도가 예상보다 크게 떨어지는 경우가 많습니다. 특히 PTQ(사후 양자화)에서는 괜찮다가 QAT로 넘어가면 갑자기 수렴이 불안정해지거나, 특정 클래스만 대규모로 틀리는 패턴이 나옵니다.

이 글은 PyTorch의 eager mode QAT(대표적으로 torch.ao.quantization) 기준으로, 정확도 하락의 흔한 원인과 “무엇을 어떻게 바꿔야 하는지”를 체크리스트처럼 정리합니다. 마지막에는 바로 복사해 쓸 수 있는 코드 스니펫(모델 fuse, qconfig 설정, 학습 루프 팁)도 포함합니다.

참고로, 성능 튜닝은 결국 “병목을 빨리 진단하고 가설을 줄이는 일”입니다. 비슷한 디버깅 접근법은 인프라 쪽에서도 동일하게 통합니다. 예를 들어 네트워크 타임아웃을 10분 안에 쪼개서 보는 방식은 이 글의 진단 흐름과 닮아 있습니다: EKS Pod→RDS 504 타임아웃 - SG·NACL·NAT 10분 진단

QAT에서 정확도가 떨어지는 대표 원인 7가지

1) fuse 누락: Conv-BN-ReLU를 합치지 않음

QAT는 fake quant가 삽입되면서 분포가 더 거칠어집니다. 이때 Conv + BN이 분리되어 있으면 BN의 스케일/시프트가 양자화 노이즈를 키우는 방향으로 작동하기 쉽습니다. 그래서 대부분의 CNN 계열은 학습 전 fuse_modules가 사실상 필수입니다.

증상:

  • QAT 시작 직후 loss가 튀거나, 정확도가 급락
  • 특히 BN이 많은 모델에서 심함

해결:

  • model.eval() 상태에서 fuse 후 다시 train()
  • fuse 가능한 블록을 모듈 구조에 맞게 정확히 지정

2) 관측기(Observer) 선택이 데이터/모델과 안 맞음

기본 MinMax 계열 observer는 outlier에 취약합니다. 활성값에 드문 큰 값이 섞이면 스케일이 커져서 대부분의 값이 양자화 그리드에서 압축되어 정보가 날아갑니다.

해결 방향:

  • 활성값에 outlier가 많으면 HistogramObserverMovingAverageMinMaxObserver 고려
  • 가중치는 보통 PerChannel을 우선 고려

3) per-tensor로 weight를 양자화해서 채널별 스케일 차이를 못 따라감

Conv weight는 채널마다 분포가 크게 다릅니다. per-tensor로 묶어버리면 작은 채널은 양자화 단계가 너무 거칠어져 성능이 떨어집니다.

해결:

  • weight는 per_channel_symmetric를 우선 적용

4) 학습 레시피 문제: LR/스케줄/워밍업이 QAT에 불리

QAT는 학습 중 fake quant로 인해 gradient가 더 noisy합니다. 기존 FP32 레시피를 그대로 쓰면 수렴이 깨질 수 있습니다.

자주 먹히는 처방:

  • QAT 구간에서 LR을 낮추거나, QAT 시작 시점에 LR drop
  • 몇 epoch는 BN freeze(또는 BN 통계 고정) 시도
  • weight decay를 약간 줄이는 것도 도움이 되는 경우가 있음

5) 활성 함수/연산이 양자화 친화적이지 않음

대표적으로 SiLU/Swish, GELU 같은 비선형은 양자화 시 오차가 커질 수 있습니다(백엔드/패턴 매칭에 따라 다름). 또 add/mul이 많은 구조(잔차/어텐션)에서 스케일 정렬이 꼬이면 손실이 커집니다.

해결:

  • 가능하면 ReLU 계열로 단순화하거나, 해당 블록만 FP로 남기는 혼합 전략
  • 문제 레이어만 float()로 유지하는 “부분 양자화”도 실전에서 자주 씁니다

6) 캘리브레이션/옵저버 통계가 대표성을 잃음

QAT라도 observer 통계는 중요합니다. 학습 데이터 분포가 서빙 분포와 다르면, fake quant가 학습 중 맞춰진 분포를 실제 입력이 깨버립니다.

해결:

  • QAT 중간에 disable_observer/enable_observer를 적절히 사용
  • 최종 변환 전, 대표 배치로 짧게 observer 통계를 재수집(미세 캘리브레이션)

7) 백엔드 설정 불일치: fbgemm/qnnpack 및 dtype/engine

x86 서버는 보통 fbgemm, ARM/모바일은 qnnpack을 씁니다. 엔진이 바뀌면 지원 패턴과 수치가 달라서 결과도 달라질 수 있습니다.

해결:

  • 개발/평가 환경에서 torch.backends.quantized.engine을 명시

기본 골격: QAT 파이프라인을 “정확한 순서”로 고정

아래 순서가 틀어지면 정확도뿐 아니라 속도도 기대와 달라질 수 있습니다.

  1. 백엔드 엔진 설정
  2. 모델 eval()로 전환
  3. fuse_modules
  4. qconfig 설정
  5. prepare_qat
  6. train()로 QAT 학습
  7. eval() 전환 후 convert

실전 코드: fuse + qconfig + QAT 학습 + convert

아래 예시는 Conv-BN-ReLU 블록을 가진 전형적인 CNN을 가정합니다. 실제 모델 구조에 맞게 fuse 대상 모듈 이름을 수정해야 합니다.

import copy
import torch
import torch.nn as nn
import torch.ao.quantization as tq

# 1) 백엔드 엔진 설정 (x86 서버라면 보통 fbgemm)
# ARM이면 qnnpack을 고려
torch.backends.quantized.engine = "fbgemm"

class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model_fp32 = SmallCNN()

# 2) fuse는 eval 상태에서 수행하는 것이 안전
model_to_qat = copy.deepcopy(model_fp32)
model_to_qat.eval()

# 3) fuse: Conv + BN + ReLU
# 모듈 이름은 모델 정의에 맞게 지정
model_to_qat = tq.fuse_modules(model_to_qat, [["conv", "bn", "relu"]], inplace=True)

# 4) qconfig 설정
# weight는 per-channel을 우선, activation은 moving average 계열이 무난
qconfig = tq.QConfig(
    activation=tq.MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
    ),
    weight=tq.MovingAveragePerChannelMinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    ),
)
model_to_qat.qconfig = qconfig

# 5) prepare_qat
model_to_qat.train()
model_prepared = tq.prepare_qat(model_to_qat, inplace=False)

# --- QAT 학습 루프 (예시) ---
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def train_one_epoch(loader):
    model_prepared.train()
    for images, labels in loader:
        optimizer.zero_grad(set_to_none=True)
        logits = model_prepared(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

# ... 여러 epoch 학습 ...

# 6) convert
model_prepared.eval()
model_int8 = tq.convert(model_prepared, inplace=False)

위 코드에서 정확도 하락이 크다면, 아래 섹션의 “정확도 회복 레버”를 우선순위대로 적용해 보세요.

정확도 회복 레버: 우선순위 높은 것부터

A) weight를 per-channel로 강제하고, activation observer를 바꿔보기

가장 흔한 승부처입니다. 특히 Conv가 많은 모델에서 per-channel 적용만으로도 정확도가 크게 회복되는 경우가 많습니다.

변형 예시(activation을 histogram으로):

qconfig = tq.QConfig(
    activation=tq.HistogramObserver.with_args(
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        bins=2048,
    ),
    weight=tq.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    ),
)
model_prepared.qconfig = qconfig

주의:

  • HistogramObserver는 준비/학습 비용이 늘 수 있습니다.
  • 데이터가 크면 bins를 줄이거나 moving average 계열로 타협합니다.

B) QAT 시작 시점과 LR을 조정하기: “FP32 수렴 후 QAT”

처음부터 QAT로 학습하면 불안정할 수 있습니다. 실전에서는 다음 패턴이 안정적입니다.

  • FP32로 대부분 수렴(예: 전체 epoch의 70%까지)
  • 마지막 30% 구간만 QAT
  • QAT 구간에서 LR을 1/10 또는 1/20로 낮춤

간단한 스위치 예시:

# pseudo
if epoch == qat_start_epoch:
    for g in optimizer.param_groups:
        g["lr"] *= 0.1

C) observer 통계 수집 구간을 컨트롤하기

학습 내내 observer가 계속 업데이트되면 분포가 흔들려 성능이 불안정해질 수 있습니다.

  • 초반 몇 epoch만 observer를 켜고
  • 이후에는 observer를 끄고(fake quant는 유지)
import torch.ao.quantization as tq

# 예: 0~2 epoch observer ON, 이후 OFF
if epoch == 2:
    model_prepared.apply(tq.disable_observer)
    # fake quant는 유지되므로 QAT 특성은 남습니다

반대로, 학습 막판에 짧게 observer를 다시 켜서 “대표 배치로 재수집”하는 것도 도움이 됩니다.

D) 문제 레이어만 FP로 남기는 부분 양자화

모든 레이어를 INT8로 만드는 것이 항상 최선은 아닙니다. 정확도 민감 구간(예: 첫 conv, 마지막 fc, 특정 residual 블록)을 FP로 남기면 정확도는 살리고 속도 이득도 일부 가져갈 수 있습니다.

전략:

  • 모델을 서브모듈 단위로 나누고, 양자화 적용 범위를 제한
  • 또는 해당 블록의 qconfig = None으로 제외
# 예: 마지막 fc는 FP로 남기기
model_to_qat.fc.qconfig = None

E) BN 처리: freeze 또는 fold 이후 QAT

BN이 특히 민감한 모델은 다음을 시도합니다.

  • QAT 구간에서 BN을 eval()로 두고 통계를 고정
  • 또는 fuse 이후 BN이 사라지므로, fuse 타이밍을 명확히

간단 예시:

def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

model_prepared.apply(set_bn_eval)

“정확도는 떨어지는데, 왜 어떤 클래스만 망가질까?”

INT8에서 특정 클래스만 급격히 무너질 때는 대개 다음 중 하나입니다.

  • 입력 분포의 특정 모드(예: 어두운 이미지, 특정 길이 문장)에서 activation outlier가 증가
  • 마지막 분류기 근처에서 스케일 정렬 문제로 logit margin이 줄어듦
  • 소수 클래스에서만 나타나는 feature가 양자화 그리드에서 소실

이럴 때는 전체 top-1만 보지 말고,

  • 클래스별 confusion
  • 입력 조건별 slice metric
  • 레이어별 activation histogram 을 함께 봐야 원인을 빨리 찾습니다.

검색/랭킹 쪽에서도 “전체 지표는 비슷한데 특정 쿼리군만 망가지는” 일이 흔합니다. 그런 경우 파라미터를 한 번에 바꾸기보다, 실패 케이스를 군집화해서 원인을 줄이는 게 효과적입니다: Pinecone·Milvus 검색품질 튜닝 - HNSW 파라미터

디버깅 체크리스트: 15분 안에 보는 항목

1) FP32 기준선 재현성

  • 같은 seed에서 FP32 정확도 재현되는지
  • 데이터 전처리/증강이 서빙과 과도하게 다르지 않은지

2) fuse 적용 여부

  • Conv-BN-ReLU가 실제로 fuse 되었는지(모듈 출력 확인)

3) qconfig 확인

  • weight가 per_channel인지
  • activation observer가 outlier에 취약한 설정인지

4) QAT 학습 안정성

  • QAT 시작 시 loss spike 여부
  • LR drop 적용 여부

5) convert 후 정확도

  • prepare_qat 상태에서의 fake quant 정확도와
  • convert 후 실제 INT8 정확도를 분리해서 비교

만약 fake quant 단계에서 이미 떨어지면 학습/옵저버/레시피 문제일 확률이 높고, convert 후에만 떨어지면 백엔드/패턴 매칭/지원 연산 문제일 확률이 큽니다.

흔한 함정: 평가 시 eval() 누락과 드롭아웃

QAT 학습 중에는 train()이지만, 평가/convert 전에는 반드시 eval()로 바꿔야 합니다. 드롭아웃/BN이 섞이면 “양자화가 문제인 것처럼 보이는” 가짜 정확도 하락이 나옵니다.

또한 실험 자동화 파이프라인에서는 이런 상태 버그가 캐시/배포 과정과 얽혀 찾기 어렵습니다. CI에서 아티팩트가 섞이거나 캐시가 안 먹으면 실험 결과가 뒤틀리기도 합니다: GitLab CI 캐시 안 먹힐 때 key·policy·권한 점검

마무리: 정확도 하락을 “구조적으로” 줄이는 접근

정확도 하락을 한 번에 해결하려고 observer, LR, fuse, 엔진을 동시에 바꾸면 원인 규명이 불가능해집니다. 다음 순서로 한 가지씩만 바꾸는 것을 권합니다.

  1. fuse 적용 확인
  2. weight per-channel 적용
  3. activation observer를 moving average 또는 histogram으로 변경
  4. QAT 구간 LR drop 및 QAT 시작 시점 조정
  5. observer enable/disable 스케줄링
  6. 부분 양자화로 민감 레이어 제외

이 과정을 거치면, 대부분의 “QAT로 INT8 만들었더니 정확도 폭락” 케이스는 재현 가능하고 해결 가능한 문제로 바뀝니다.

원하시면 모델 타입(예: ResNet, MobileNetV3, ViT, Transformer encoder), 타깃 백엔드(서버 fbgemm인지, 모바일 qnnpack인지), 그리고 현재 쓰는 qconfig/학습 레시피를 기준으로 더 구체적인 처방(예: 어떤 블록을 FP로 남길지, 어떤 observer가 유리한지)까지 맞춰서 정리해 드릴 수 있습니다.