Published on

PyTorch PTQ·QAT 정확도 급락 원인·복구 7단계

Authors

서빙 지연을 줄이거나 엣지 디바이스에 배포하기 위해 PyTorch 양자화(Quantization)를 적용했는데, PTQ(Post-Training Quantization)나 QAT(Quantization Aware Training) 이후 정확도가 갑자기 무너지는 경우가 흔합니다. 문제는 대부분 “양자화 자체”가 아니라, 관측기(observer) 통계 수집, 스케일링, 연산자 교체, 데이터 전처리, 평가 모드, 백엔드 설정 같은 주변 조건의 작은 불일치에서 시작됩니다.

이 글은 PTQ와 QAT에서 정확도 급락을 일으키는 전형적인 원인을 정리하고, 실제로 복구까지 이어지는 7단계 절차를 제공합니다. 원인 진단은 체크리스트형으로, 복구는 바로 적용 가능한 PyTorch 코드 중심으로 설명합니다.

비슷한 문제를 “원인별로 쪼개서 체크리스트로 복구”하는 접근은 인프라 장애 진단에도 유효합니다. 예를 들어 서비스가 계속 재시작될 때도 같은 방식으로 원인을 좁힙니다. 참고로 systemd 서비스가 계속 재시작될 때 진단 체크리스트 글의 구조가 이런 문제 해결 흐름과 닮아 있습니다.

먼저: PTQ와 QAT에서 정확도 급락이 생기는 대표 패턴

아래 패턴 중 하나라도 해당되면, “양자화가 잘못됐다”가 아니라 “양자화 파이프라인의 전제 조건이 깨졌다”일 가능성이 큽니다.

  • 캘리브레이션(calibration) 데이터가 너무 적거나, 실제 추론 분포와 다름
  • model.eval() 누락, BatchNorm 통계가 흔들림, Dropout이 켜진 상태로 관측기 통계 수집
  • per-tensor 스케일로 인해 채널별 분포가 큰 Conv가 심하게 손상됨(per-channel 필요)
  • activation 범위가 긴 모델(Transformer, detection head 등)에서 기본 MinMax observer가 과도하게 클리핑되거나 반대로 너무 넓게 잡힘
  • 양자화 가능한 연산자 패턴이 끊겨서(예: Conv + BN + ReLU가 분리됨) 기대한 fuse가 안 됨
  • 백엔드(fbgemm, qnnpack) 불일치 또는 지원 연산자 차이
  • QAT에서 fake quant를 너무 늦게/너무 이르게 켜거나, 학습률 스케줄이 양자화 노이즈를 못 따라감

이제부터는 이 패턴을 “7단계 복구 절차”로 구체화합니다.

7단계 복구 절차(PTQ·QAT 공통)

1단계: 베이스라인을 고정하고, 비교 실험이 가능한지 확인

정확도 급락을 논하기 전에, 다음이 고정되어야 합니다.

  • 동일한 전처리/후처리
  • 동일한 평가 코드
  • 동일한 체크포인트
  • 동일한 랜덤 시드(가능하면)

특히 PTQ는 캘리브레이션 데이터 순서가 통계에 영향을 주기도 합니다. 아래처럼 최소한의 베이스라인 평가 코드를 만들어 FP32와 양자화 모델을 같은 루프로 비교하세요.

import torch

def evaluate(model, dataloader, device="cuda"):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    return correct / max(total, 1)

체크 포인트: FP32 정확도가 흔들린다면 양자화 문제가 아니라 평가 파이프라인 문제일 수 있습니다.

2단계: 캘리브레이션 데이터 분포를 “추론 분포”에 맞추기

PTQ에서 가장 흔한 원인은 캘리브레이션 데이터가 너무 적거나, 실제 서비스 입력과 분포가 다른 것입니다. 예를 들어 학습 데이터 일부를 무작위로 뽑았는데 서비스에서는 야간/저조도 이미지가 많다면 activation 범위가 달라져 양자화 스케일이 틀어집니다.

권장 기준(경험칙):

  • 분류 모델: 최소 수백~수천 샘플
  • detection/segmentation: 최소 수천 샘플, 특히 다양한 해상도/스케일 포함
  • NLP: 길이 분포(짧은 문장만 캘리브레이션하면 긴 문장에서 망가짐)를 맞추기

캘리브레이션 루프는 반드시 eval 모드에서, no_grad로 실행하세요.

def calibrate(model, calib_loader, device="cpu"):
    model.eval()
    with torch.no_grad():
        for x, _ in calib_loader:
            model(x.to(device))

체크 포인트: 캘리브레이션을 “학습 데이터”가 아니라 “실제 추론 트래픽 샘플”로 하는 것만으로도 급락이 복구되는 경우가 많습니다.

3단계: fuse가 제대로 되었는지, 양자화 가능한 그래프인지 점검

정확도뿐 아니라 성능도 목표라면, Conv + BN + ReLU 같은 패턴은 fuse 되어야 합니다. fuse가 되지 않으면 관측기 삽입 위치가 달라지고, 스케일링이 예상과 달라져 정확도에도 영향을 줍니다.

전형적인 실수:

  • nn.Sequential 밖에서 연산을 분리해 fuse 패턴이 끊김
  • F.relu 같은 functional 호출과 모듈 호출이 섞여 추적이 꼬임

예시(가능한 형태로 구성):

import torch.nn as nn
import torch.ao.quantization as aq

class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

model = Block(3, 16)
model.eval()

# fuse
model_fused = torch.ao.quantization.fuse_modules(model, [["conv", "bn", "relu"]], inplace=False)

체크 포인트: fuse 이후 FP32 정확도가 미세하게라도 변하면, 원래 모델이 eval/train 모드에 따라 동작이 달랐거나 BN이 불안정했을 수 있습니다.

4단계: qconfig와 observer를 모델 특성에 맞게 교체

기본 설정이 항상 최선은 아닙니다. 특히 activation 분포가 긴 모델에서 MinMax 기반 observer는 outlier에 취약합니다. 이때 histogram 기반이나 percentile 기반(직접 구현) 접근이 더 안정적일 수 있습니다.

또한 Conv 계열은 per-channel weight quant가 정확도 방어에 매우 중요합니다.

import torch
import torch.ao.quantization as aq

backend = "fbgemm"  # x86 서버
# backend = "qnnpack"  # ARM

torch.backends.quantized.engine = backend

# 기본 qconfig (backend에 따라 권장 observer가 다름)
qconfig = aq.get_default_qconfig(backend)

# 예: per-channel weight가 포함된 qconfig가 설정되는지 확인
print(qconfig)

추가 팁:

  • CNN: per-channel weight는 사실상 필수에 가깝습니다.
  • Transformer 계열: activation quant 방식 자체가 민감합니다. 가능하면 먼저 Linear만 양자화하거나, 특정 블록만 부분 양자화부터 시작하세요.

체크 포인트: “전부 양자화”가 아니라 “정확도 민감 레이어 제외” 전략이 실전에서 자주 쓰입니다.

5단계: PTQ 파이프라인을 표준 순서로 재구성(prepare, calibrate, convert)

PyTorch eager mode PTQ의 기본 흐름은 다음 순서가 안전합니다.

  1. eval 전환
  2. fuse
  3. qconfig 지정
  4. prepare
  5. 캘리브레이션(관측기 통계 수집)
  6. convert
  7. 양자화 모델 평가
import copy
import torch.ao.quantization as aq

fp32_model = model_fused
fp32_model.eval()

ptq_model = copy.deepcopy(fp32_model)
ptq_model.qconfig = aq.get_default_qconfig(torch.backends.quantized.engine)

prepared = aq.prepare(ptq_model, inplace=False)
calibrate(prepared, calib_loader, device="cpu")

quantized = aq.convert(prepared, inplace=False)

자주 하는 실수:

  • prepare 전에 qconfig를 안 넣음
  • 캘리브레이션을 train 모드로 돌림
  • 캘리브레이션을 GPU에서 돌리고 convert는 CPU에서 돌리는 등 디바이스 이동이 섞임

체크 포인트: 양자화 모델은 보통 CPU 추론을 전제로 합니다. GPU에서 양자화 이득을 기대하면 다른 스택이 필요합니다.

6단계: QAT는 “학습 스케줄”과 “BN 처리”를 같이 설계

QAT에서 정확도 급락이 나는 경우는 크게 두 가지입니다.

  • fake quant 노이즈가 들어오는데 학습률/스케줄이 이를 흡수하지 못함
  • BatchNorm이 계속 업데이트되며 분포가 흔들리고, fake quant 통계와 충돌

권장 접근:

  • QAT 시작 시점에 LR을 낮추고, 짧은 warmup을 둠
  • BN을 freeze 하거나(상황에 따라), 최소한 QAT 중에는 BN 통계를 안정화

PyTorch QAT 기본 흐름 예시:

import copy
import torch
import torch.ao.quantization as aq

qat_model = copy.deepcopy(model_fused)
qat_model.train()

qat_model.qconfig = aq.get_default_qat_qconfig(torch.backends.quantized.engine)
qat_prepared = aq.prepare_qat(qat_model, inplace=False)

optimizer = torch.optim.AdamW(qat_prepared.parameters(), lr=1e-4)

for epoch in range(3):
    qat_prepared.train()
    for x, y in train_loader:
        optimizer.zero_grad(set_to_none=True)
        logits = qat_prepared(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

qat_prepared.eval()
qat_converted = aq.convert(qat_prepared, inplace=False)

BN freeze의 한 방법(간단 버전):

import torch.nn as nn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        for p in m.parameters():
            p.requires_grad = False

qat_prepared.apply(freeze_bn)

체크 포인트: QAT는 “몇 epoch 더 학습”이 아니라 “양자화 노이즈에 맞는 미세조정”입니다. 학습률을 FP32 학습과 동일하게 두면 급락하는 경우가 많습니다.

7단계: 레이어별 민감도 분석으로 “부분 양자화”와 예외 처리를 확정

마지막 단계는 원인을 좁힌 뒤에도 남는 손실을 줄이는 실전 단계입니다.

  • 특정 head(예: detection cls/reg head)만 FP32 유지
  • 첫 Conv, 마지막 FC는 FP32 유지
  • Softmax 직전 logits 스케일이 민감하면 그 주변만 FP32 유지

Eager mode에서 완벽한 자동 부분 양자화는 까다롭지만, 모듈 단위로 qconfig = None 처리하는 방식이 흔합니다.

def disable_quant_for_module(module):
    module.qconfig = None

# 예: classifier는 FP32 유지
if hasattr(ptq_model, "classifier"):
    disable_quant_for_module(ptq_model.classifier)

또 다른 방법은 아예 양자화 대상 블록을 QuantStub/DeQuantStub로 감싸 범위를 제한하는 것입니다.

import torch.nn as nn
import torch.ao.quantization as aq

class QuantWrap(nn.Module):
    def __init__(self, core):
        super().__init__()
        self.quant = aq.QuantStub()
        self.core = core
        self.dequant = aq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x

체크 포인트: “정확도 0.5% 손해로 2배 가속”이 목표라면, 전부 양자화보다 부분 양자화가 더 빨리 목표를 달성합니다.

디버깅 팁: 관측기 통계와 스케일을 눈으로 확인

정확도 급락이 생겼다면, 실제로 어떤 레이어의 activation 범위가 비정상인지 확인하는 게 빠릅니다. prepare 이후 모듈에 observer가 붙어 있으므로, 특정 모듈의 observer를 찾아 min/max 또는 scale 정보를 출력해 보세요.

def print_observers(m, prefix=""):
    for name, child in m.named_children():
        full = f"{prefix}.{name}" if prefix else name
        # observer는 보통 activation_post_process 이름으로 달림
        if hasattr(child, "activation_post_process"):
            obs = child.activation_post_process
            print(full, type(obs))
        print_observers(child, full)

print_observers(prepared)

여기서 특정 블록만 유독 observer가 과도한 범위를 잡거나, 반대로 너무 좁게 잡으면 해당 블록이 급락의 진원지일 가능성이 큽니다.

PTQ vs QAT 선택 가이드(정확도 급락 관점)

  • PTQ로 목표 정확도를 맞출 수 있으면 PTQ가 운영 난이도가 훨씬 낮습니다.
  • PTQ에서 급락이 큰 모델(특히 activation 분포가 민감한 구조)은 QAT가 사실상 필수인 경우가 많습니다.
  • QAT는 비용이 크므로, 먼저 PTQ를 “캘리브레이션 개선 + observer 조정 + 부분 양자화”까지 해보고, 그래도 안 되면 QAT로 넘어가는 순서가 안전합니다.

운영 관점 체크리스트: 재현 가능한 실험 기록 남기기

양자화는 모델 구조뿐 아니라 데이터, 백엔드, PyTorch 버전, 심지어 CPU 아키텍처에 따라 결과가 달라집니다. 아래 항목을 실험 로그에 남기면, 급락이 재발했을 때 복구 속도가 크게 빨라집니다.

  • PyTorch 버전, torch.backends.quantized.engine
  • 캘리브레이션 샘플 수와 샘플링 방식
  • fuse 적용 여부
  • qconfig 종류(PTQ, QAT)
  • 제외한 모듈 목록(부분 양자화)

이런 식의 “원인별 체크”는 다른 분야에서도 동일하게 통합니다. 예를 들어 네트워크 라우팅이 안 될 때도 체크리스트로 빠르게 좁힙니다. 참고로 GCP VPC 피어링 라우트 안됨? 9가지 체크도 같은 문제 해결 스타일입니다.

마무리: 정확도 급락은 대부분 복구 가능하다

PTQ/QAT 이후 정확도 급락은 흔하지만, 대개는 다음 3가지 축에서 해결됩니다.

  • 통계 수집이 올바른가(캘리브레이션, eval 모드, 데이터 분포)
  • 그래프가 양자화 친화적인가(fuse, 연산자 패턴, 백엔드)
  • 모델 특성에 맞는 스케일링 전략인가(observer, per-channel, 부분 양자화)

위 7단계를 순서대로 적용하면 “왜 떨어졌는지”를 설명할 수 있는 상태로 만들고, 대부분의 케이스에서 정확도를 실용 수준까지 되돌릴 수 있습니다. 만약 특정 아키텍처(예: Transformer, detection)에서 계속 막힌다면, 다음 글로는 레이어별 민감도 분석 자동화, 부분 양자화 템플릿, 그리고 torch.export 기반 양자화 흐름까지 확장하는 것을 추천합니다.