Published on

PyTorch QAT INT8 변환 후 정확도 하락 잡기

Authors

서버 추론 비용을 줄이려고 PyTorch의 QAT(Quantization Aware Training)로 INT8 변환을 적용했는데, 변환 직후 정확도가 눈에 띄게 하락하는 경우가 자주 있습니다. QAT는 원래 PTQ(Post Training Quantization)보다 정확도 보존이 유리하지만, 설정 한두 군데만 어긋나도 학습은 잘 되는 듯 보이다가 변환 후 성능이 급격히 무너질 수 있습니다.

이 글에서는 torch.ao.quantization 기반 QAT 파이프라인을 전제로, 정확도 하락을 “원인별로” 빠르게 좁히는 진단 순서재현 가능한 수정 패턴을 정리합니다.

추가로, 운영에서 흔히 겪는 “캐시/설정 불일치로 인한 실패”와 유사하게 QAT도 옵저버/스케일/백엔드 설정 불일치가 핵심 원인인 경우가 많습니다. 그런 종류의 트러블슈팅 감각은 JWT 검증 실패 - JWKS kid 불일치·캐시 7가지 같은 글의 접근법과도 닮아 있습니다.

먼저 확인: QAT 정확도 하락의 80%는 “불일치”

QAT에서 정확도 하락이 크게 나는 패턴은 대체로 다음 중 하나입니다.

  1. 학습 시점과 변환 시점의 설정 불일치

    • qconfig를 바꿨는데 다시 prepare_qat를 안 함
    • fuse_modules 전후 순서가 뒤섞임
    • 학습은 fbgemm 가정인데 변환/실행은 다른 백엔드
  2. 옵저버(Observer) 캘리브레이션 품질 문제

    • 학습 초기에 스케일이 요동치는데 freeze 타이밍이 너무 빠름
    • 분포가 긴 꼬리인 텐서에 MinMaxObserver를 써서 스케일이 망가짐
  3. 연산자/레이어가 양자화 친화적이지 않음

    • SiLU, GELU, Softmax 주변에서 스케일이 깨짐
    • add/cat 등 합류 지점에서 스케일 정렬이 안 됨
  4. 정밀도 혼합 정책이 잘못됨

    • 일부 민감 레이어는 FP16 또는 FP32로 남겨야 하는데 전부 INT8로 강제

이제부터는 “어디서부터 봐야 가장 빨리 원인을 찾는지” 순서대로 설명합니다.

1) 베이스라인부터 고정: FP32, FakeQuant, INT8를 분리 측정

정확도 하락을 잡으려면 세 가지 정확도를 분리해서 봐야 합니다.

  • FP32 원본 모델 정확도
  • QAT 준비 후 FakeQuant(가짜 양자화) 상태에서의 정확도
  • convert 후 실제 INT8 모델 정확도

여기서 중요한 관찰:

  • FakeQuant에서도 정확도가 크게 떨어지면: 학습/옵저버/그래프 구성 문제
  • FakeQuant는 괜찮은데 INT8에서만 떨어지면: convert/백엔드/지원 연산자/폴딩 문제 가능성이 큼

아래는 최소한의 측정 뼈대입니다.

import torch
import torch.ao.quantization as tq

@torch.no_grad()
def evaluate(model, dataloader, device="cpu"):
    model.eval()
    correct = 0
    total = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)

# 1) FP32
acc_fp32 = evaluate(model_fp32, val_loader)

# 2) FakeQuant(QAT 준비 후)
acc_fake = evaluate(model_qat_prepared, val_loader)

# 3) INT8(convert 후)
acc_int8 = evaluate(model_int8, val_loader)

print({"fp32": acc_fp32, "fake": acc_fake, "int8": acc_int8})

2) QAT의 정석 파이프라인: fuse → prepare_qat → train → convert

정확도 하락 이슈의 상당수는 순서가 틀리거나 중간 상태 모델을 재사용하면서 생깁니다.

권장 순서:

  1. eval() 상태에서 fuse_modules
  2. train() 상태로 전환
  3. qconfig 설정
  4. prepare_qat
  5. QAT 파인튜닝
  6. eval() 전환
  7. convert
import copy
import torch
import torch.ao.quantization as tq

# 백엔드 선택: x86 서버면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"

model = copy.deepcopy(model_fp32)
model.eval()

# 예: Conv+BN+ReLU fuse (모델 구조에 맞게 수정)
# tq.fuse_modules는 모듈 경로 문자열 리스트를 받음
model = tq.fuse_modules(model, [["conv", "bn", "relu"]], inplace=False)

# QAT 준비
model.train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
model_prepared = tq.prepare_qat(model, inplace=False)

# QAT 파인튜닝(짧게라도 필수)
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=1e-4, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(3):
    model_prepared.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        out = model_prepared(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

# 변환
model_prepared.eval()
model_int8 = tq.convert(model_prepared, inplace=False)

체크 포인트:

  • fuse_modules는 반드시 eval()에서 수행하는 것이 일반적입니다.
  • prepare_qat 이후에 qconfig를 바꾸면 의미가 없습니다. 바꿨다면 다시 준비해야 합니다.
  • 학습이 너무 짧으면 옵저버 통계가 안정화되지 않아 INT8 스케일이 불안정할 수 있습니다.

3) 옵저버가 문제인지 확인: activation 분포에 맞는 observer 선택

정확도 하락이 큰데 학습 로그는 멀쩡하다면, 종종 activation 스케일이 망가진 것입니다.

대표적으로:

  • 분포가 긴 꼬리(outlier)를 가지면 MinMaxObserver는 스케일이 과도하게 커져 유효 비트가 줄어듭니다.
  • 이때는 HistogramObserver 또는 MovingAverageMinMaxObserver가 더 안정적인 경우가 많습니다.

커스텀 qconfig 예시:

import torch.ao.quantization as tq

activation_observer = tq.HistogramObserver.with_args(
    reduce_range=False
)

weight_observer = tq.default_per_channel_weight_observer

qconfig = tq.QConfig(
    activation=activation_observer,
    weight=weight_observer
)

model.qconfig = qconfig
model_prepared = tq.prepare_qat(model, inplace=False)

추가 팁:

  • CNN 계열은 weight per-channel이 정확도에 크게 유리한 경우가 많습니다.
  • 반대로 일부 모바일/특정 백엔드에서는 per-tensor만 지원하거나 성능/정확도 트레이드오프가 달라집니다.

4) 옵저버 freeze 타이밍: 너무 빨리 얼리면 망한다

QAT에서는 보통 학습 중간에 다음을 수행합니다.

  • 옵저버 업데이트 중지(disable_observer)
  • fake quant 고정(freeze_bn_stats 또는 BN 관련 처리)

너무 이르게 freeze하면, 아직 분포가 안정화되지 않아 잘못된 스케일로 굳어버립니다.

import torch.ao.quantization as tq

def set_qat_freeze(model):
    # 옵저버 비활성화
    model.apply(tq.disable_observer)
    # FakeQuant 고정(스케일/제로포인트 업데이트 중지)
    model.apply(tq.disable_fake_quant)

# 예: 전체 epoch 중 후반부에만 freeze
for epoch in range(10):
    train_one_epoch(model_prepared)
    if epoch == 7:
        model_prepared.apply(tq.disable_observer)
    if epoch == 8:
        model_prepared.apply(tq.disable_fake_quant)

주의:

  • 프로젝트에 따라 freeze 정책이 다릅니다. 핵심은 스케일 통계가 충분히 수렴한 뒤에 고정하는 것입니다.

5) 민감 레이어는 INT8로 보내지 말고 “선별적으로” 남겨라

정확도 하락이 특정 블록에서만 발생한다면, 전부를 INT8로 만들기보다 일부 연산은 FP32로 유지하는 전략이 효과적입니다.

실무에서 자주 민감한 구간:

  • 입력단/출력단(첫 Conv, 마지막 FC)
  • Attention, Softmax 주변
  • 작은 채널 수에서의 depthwise 계열

선별적으로 qconfig = None을 주는 패턴:

import torch.ao.quantization as tq

# 예: 마지막 분류기 레이어는 FP32 유지
model.classifier.qconfig = None

# 또는 특정 서브모듈 전체를 제외
model.head.qconfig = None

model_prepared = tq.prepare_qat(model, inplace=False)

포인트:

  • “정확도에 민감한 곳만 FP32로 남기기”는 속도 이득을 크게 해치지 않으면서 품질을 회복하는 경우가 많습니다.

6) 합류 지점(add/cat)에서 스케일이 깨질 때: QuantStub/DeQuantStub 위치 점검

ResNet류처럼 add가 많은 네트워크는 합류 지점에서 두 텐서의 스케일/제로포인트 정렬이 관건입니다.

일반적으로는 QuantStubDeQuantStub을 모델 입출력에 배치하고, 내부는 FX 그래프 모드가 더 안정적인 편입니다. 다만 eager mode를 쓴다면 스텁 위치가 어긋나 floatquantized가 섞여 예상치 못한 디퀀트가 발생할 수 있습니다.

간단한 스텁 예시:

import torch
import torch.nn as nn
import torch.ao.quantization as tq

class QuantWrapper(nn.Module):
    def __init__(self, core):
        super().__init__()
        self.quant = tq.QuantStub()
        self.core = core
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x

model = QuantWrapper(model_fp32)

증상 기반 힌트:

  • FakeQuant는 괜찮은데 INT8만 급락한다면, 합류 지점에서의 quantize/dequantize 삽입이 달라졌을 가능성이 있습니다.

7) 백엔드/커널 문제: fbgemm vs qnnpack 불일치

QAT는 “학습 때 가정한 양자화 스킴”과 “실제 실행 커널”이 맞아야 합니다.

  • 서버 x86: 보통 fbgemm
  • 모바일 ARM: 보통 qnnpack

학습/변환/서빙 환경이 다르면 정확도뿐 아니라 출력 자체가 달라 보일 수 있습니다.

import torch

torch.backends.quantized.engine = "fbgemm"  # 또는 "qnnpack"
print(torch.backends.quantized.engine)

운영에서 환경 차이로 문제가 생기는 패턴은 다른 분야에서도 흔합니다. 예를 들어 EKS에서 노드/보안 설정 차이로 특정 기능만 실패하는 경우처럼요. 비슷한 트러블슈팅 접근은 EKS에서 kubectl exec·logs가 안 될 때 진단법 같은 글의 “환경-권한-경로를 나눠서 확인”하는 방식이 참고가 됩니다.

8) 데이터 전처리 불일치: INT8에서만 더 크게 터진다

QAT 자체보다 더 자주 놓치는 것이 전처리/정규화 불일치입니다.

  • FP32 학습/평가에서는 약간의 전처리 차이가 티가 덜 나는데
  • INT8에서는 activation 범위가 제한되면서 작은 불일치가 더 큰 손실로 증폭됩니다.

체크리스트:

  • 정규화(mean/std)가 학습과 동일한가
  • 입력 dtype이 float32로 들어가고 있는가
  • 채널 순서(NCHW) 및 스케일(0-1, 0-255)이 동일한가

간단한 가드 코드:

def assert_input_ok(x):
    assert x.dtype == torch.float32
    assert x.ndim == 4
    # 값 범위 점검(예: 0-1 가정)
    assert x.min().item() >= -0.1 and x.max().item() <= 1.1

9) 디버깅 실전: 어디 레이어에서 오차가 폭발하는지 찾기

정확도 하락을 “감”으로 때려 맞추지 말고, 레이어별로 FP32 출력과 FakeQuant 또는 INT8 출력 차이를 측정해 병목을 찾는 게 빠릅니다.

훅 기반 간단 비교:

import torch

def collect_activations(model, layer_names):
    acts = {}
    hooks = []

    name_to_module = dict(model.named_modules())

    def make_hook(name):
        def hook(m, inp, out):
            # out이 tuple일 수 있어 방어
            t = out[0] if isinstance(out, (tuple, list)) else out
            acts[name] = t.detach().float().cpu()
        return hook

    for n in layer_names:
        hooks.append(name_to_module[n].register_forward_hook(make_hook(n)))

    return acts, hooks

@torch.no_grad()
def compare_one_batch(model_a, model_b, x, layer_names):
    acts_a, hooks_a = collect_activations(model_a, layer_names)
    acts_b, hooks_b = collect_activations(model_b, layer_names)

    _ = model_a(x)
    _ = model_b(x)

    for h in hooks_a + hooks_b:
        h.remove()

    diffs = {}
    for n in layer_names:
        a = acts_a[n]
        b = acts_b[n]
        diffs[n] = (a - b).abs().mean().item()
    return diffs

# 사용 예: 중요한 블록 몇 개만 찍어도 방향이 보임
layer_names = [
    "core.layer1.0.conv1",
    "core.layer2.0.conv1",
    "core.layer4.1.conv2",
]

diffs = compare_one_batch(model_fp32.eval(), model_prepared.eval(), x_batch, layer_names)
print(diffs)

이렇게 “오차가 급증하는 지점”을 찾은 뒤, 그 블록만 qconfig=None으로 제외하거나 observer를 바꾸는 식으로 빠르게 수렴시킬 수 있습니다.

10) 흔한 처방전(우선순위 순)

정확도 하락을 만났을 때, 아래 순서로 적용하면 시행착오가 줄어듭니다.

  1. 측정 분리: FP32 vs FakeQuant vs INT8
  2. 파이프라인 순서 고정: fuseprepare_qat, 그리고 변환 전 eval
  3. 백엔드 통일: torch.backends.quantized.engine 확인
  4. 옵저버 교체: MinMax에서 Histogram 또는 MovingAverage
  5. freeze 타이밍 늦추기: 후반부에 observer/fakequant 고정
  6. 민감 레이어 제외: 마지막 FC, 첫 Conv, 특정 블록을 FP32 유지
  7. 전처리 재검증: 정규화/스케일/채널 순서
  8. 레이어별 오차 추적: 훅으로 폭발 지점 찾기

마무리

PyTorch QAT에서 INT8 변환 후 정확도 하락은 “QAT가 원래 어렵다”기보다, 대부분 설정과 상태의 불일치, 그리고 옵저버 통계의 부정확에서 시작합니다.

핵심은 한 번에 여러 가설을 건드리지 말고, FakeQuantINT8의 차이를 분리해 관측한 뒤, 옵저버/백엔드/민감 레이어를 작게, 반복적으로 조정하는 것입니다.

다음 단계로는 모델이 CNN인지 Transformer인지, 그리고 목표 런타임이 torchscript인지 onnxruntime인지에 따라 최적 처방이 달라집니다. 원하시면 모델 종류와 현재 qconfig, 정확도 하락 폭(top1, mAP 등), 그리고 FakeQuant 단계 정확도를 기준으로 더 구체적인 튜닝 플랜까지 정리해드릴 수 있습니다.