Published on

PyTorch PTQ 정확도 급락? Calib 튜닝 실전

Authors

INT8 PTQ(Post-Training Quantization)를 적용했더니 모델 크기와 지연 시간은 줄었는데, 정확도가 눈에 띄게 급락하는 경우가 흔합니다. 특히 PyTorch PTQ는 calibration 단계에서 수집되는 통계(activation min/max 혹은 histogram)가 양자화 스케일을 사실상 결정하기 때문에, Calib이 조금만 어긋나도 전체 성능이 무너질 수 있습니다.

이 글은 “왜 떨어졌는지”를 빠르게 좁히고, “무엇을 어떻게 튜닝해야 하는지”를 PyTorch FX Graph Mode 기준으로 정리합니다. (Eager Mode도 원리는 동일합니다.)

관련해서 ONNX로 내보내 INT8까지 연결하는 전체 파이프라인이 필요하면 아래 글도 함께 보면 흐름이 이어집니다.

PTQ 정확도 급락의 80%: Calib 통계가 틀렸다

PTQ는 학습 없이(또는 아주 제한적인 보정만으로) float32 모델을 int8로 바꿉니다. 이때 가장 중요한 건 각 텐서(특히 activation)의 분포를 얼마나 “현실과 가깝게” 관측했는지입니다.

정확도 급락을 유발하는 대표적인 Calib 실패 패턴은 다음과 같습니다.

  1. Calib 데이터가 실제 추론 분포를 대표하지 못함
    • 학습 데이터 일부를 대충 샘플링했는데, 실제 서비스 입력(길이, 밝기, 도메인)이 다름
    • 이미지라면 resize/crop/normalize가 추론과 다름
    • NLP라면 토크나이저, padding, max length가 다름
  2. Calib 배치 수가 부족
    • min/max 기반 옵저버는 outlier 한 번에 스케일이 망가짐
    • histogram 기반도 bin이 안정화되려면 샘플이 필요
  3. 옵저버(Observer) 선택이 모델 특성과 불일치
    • outlier가 많은 activation에 per-tensor min/max를 쓰면 정보 손실이 큼
  4. per-channel/per-tensor 설정 미스
    • Conv/Linear weight는 보통 per-channel이 유리한데 per-tensor로 묶어버림
  5. 양자화에서 취약한 연산이 그대로 int8로 내려감
    • Softmax, LayerNorm, attention score 같은 구간은 quant-friendly하지 않음
    • 이 경우는 “Calib만”으로는 한계가 있어 부분적으로 FP로 남기는 전략이 필요

이제부터는 “어디서” 깨지는지 확인하고, “Calib”을 어떻게 튜닝할지 단계별로 들어가겠습니다.

1) 먼저 재현 가능한 측정 루프를 만든다

Calib 튜닝은 감으로 하면 끝이 없습니다. 아래 3가지를 고정해두면 원인 규명이 빨라집니다.

  • 동일한 eval 데이터셋(혹은 서비스 로그 샘플)
  • 동일한 전처리 파이프라인
  • 동일한 metric 계산 코드

아래는 분류 문제 기준의 최소 평가 루프 예시입니다.

import torch

def evaluate_top1(model, dataloader, device="cuda"):
    model.eval()
    correct = 0
    total = 0
    with torch.inference_mode():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    return correct / max(total, 1)

이제 FP32 baselinePTQ INT8를 같은 루프로 비교하면서, Calib 설정을 바꿀 때마다 수치가 어떻게 움직이는지 확인합니다.

2) PyTorch FX PTQ 기본 파이프라인(prepare/calibrate/convert)

FX Graph Mode(권장)는 대략 다음 흐름입니다.

  1. prepare_fx: 옵저버 삽입(통계 수집 준비)
  2. calibration: 대표 데이터로 forward를 여러 번 돌려 통계 수집
  3. convert_fx: 수집된 통계로 quant/dequant 노드를 확정하고 int8 커널로 변환
import torch
import torch.ao.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

backend = "fbgemm"  # x86 서버
# backend = "qnnpack"  # ARM/모바일

torch.backends.quantized.engine = backend

# 예시: ResNet류는 기본 qconfig가 어느 정도 동작
qconfig = tq.get_default_qconfig(backend)
qconfig_dict = {"": qconfig}

example_inputs = (torch.randn(1, 3, 224, 224),)

model_fp32 = model_fp32.eval()
prepared = prepare_fx(model_fp32, qconfig_dict, example_inputs)

# --- calibration ---
with torch.inference_mode():
    for i, (x, _) in enumerate(calib_loader):
        prepared(x)
        if i >= 200:  # 시작점: 200~1000 배치 사이에서 튜닝
            break

model_int8 = convert_fx(prepared)

정확도가 급락한다면, 대부분은 위 코드의 qconfigcalibration loop를 튜닝해야 합니다.

3) Calib 데이터 튜닝: “개수”보다 “대표성”

전처리 불일치부터 잡는다

Calib에 들어가는 입력은 “추론 입력과 동일한 전처리”여야 합니다. 특히 아래 항목이 자주 어긋납니다.

  • 이미지: normalize(mean/std), resize 방식, center crop 여부
  • NLP: tokenizer 버전, max_length, truncation, padding 정책
  • 오디오: sample rate, window, mel 파라미터

전처리가 다르면 activation 분포 자체가 달라져서, 옵저버가 완전히 다른 스케일을 학습합니다.

Calib 샘플 수 가이드

정답은 모델/데이터에 따라 다르지만, 실전에서 자주 쓰는 시작점은 다음과 같습니다.

  • vision/classification: 500~2,000 샘플
  • detection/segmentation: 1,000~5,000 샘플 (분포가 더 다양)
  • NLP: 수백~수천 문장 (길이 분포를 반드시 포함)

단, “많이”보다 중요한 건 outlier를 포함하되, outlier에 지배당하지 않게 하는 것입니다. 이 지점에서 옵저버 선택이 중요해집니다.

4) 옵저버 튜닝: MinMax vs Histogram vs PerChannel

PyTorch PTQ에서 activation 옵저버는 대체로 다음 선택지로 좁혀집니다.

  • MinMaxObserver: 가장 단순, outlier에 취약
  • HistogramObserver: 분포를 히스토그램으로 보고, outlier 영향 완화 가능
  • PerChannelMinMaxObserver: 주로 weight에 사용(Conv/Linear)

(권장) Weight는 per-channel을 기본으로

Conv/Linear weight는 per-channel이 정확도에 유리한 경우가 많습니다. 기본 qconfig가 이를 어느 정도 반영하지만, 커스텀 설정이 필요할 때가 있습니다.

import torch.ao.quantization as tq

act_observer = tq.HistogramObserver.with_args(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
)

wt_observer = tq.PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
)

qconfig = tq.QConfig(activation=act_observer, weight=wt_observer)
qconfig_dict = {"": qconfig}

위 조합은 “activation은 histogram으로 outlier를 완화”하고 “weight는 per-channel로 표현력을 확보”하는 전형적인 튜닝 레시피입니다.

reduce_range는 상황 따라

reduce_range=True는 일부 백엔드에서 오버플로우/호환성 이슈를 피하려고 쓰지만, 표현 범위를 줄여 정확도에 영향을 줄 수 있습니다. 최근 환경에서는 보통 False로 두고 시작한 뒤, 필요할 때만 켭니다.

5) Calib 루프 튜닝: dropout, BN, dynamic shape

반드시 eval() + inference_mode()

Calib은 통계 수집이 목적이므로, 학습 시 랜덤성이 들어가면 분포가 흔들립니다.

  • model.eval() 필수 (Dropout off, BN freeze)
  • torch.inference_mode() 권장

BatchNorm이 특히 민감한 이유

BN은 eval()에서 러닝 스탯을 사용하므로, Calib 시점에 BN 자체는 업데이트되지 않습니다. 즉, BN이 이미 학습된 상태라면 Calib에서 BN을 “바꾸는” 게 아니라, BN 출력 분포를 “관측”하는 과정입니다.

문제는 다음입니다.

  • BN 출력이 특정 범위로 잘 정규화되어 있을 거라 기대했는데
  • 입력 전처리/도메인이 달라 BN 출력 분포가 크게 치우치면
  • activation quantization이 급격히 손실을 유발

이 경우 Calib 데이터 대표성을 올리는 게 1순위이고, 그 다음이 옵저버를 histogram으로 바꾸는 것입니다.

dynamic shape 입력은 길이 분포를 섞어라

NLP나 가변 해상도 입력은 “짧은 것만” 혹은 “긴 것만”으로 Calib하면 특정 길이에서만 잘 동작합니다. 길이/해상도 버킷을 나눠 샘플링하는 방식이 안정적입니다.

6) 어디서 깨지는지 찾기: 레이어별 민감도 점검

정확도 급락이 발생하면 “특정 블록”이 원인인 경우가 많습니다. 아래처럼 레이어별로 출력 통계를 찍으면 힌트를 얻을 수 있습니다.

import torch

def add_activation_hooks(model, names_to_watch):
    stats = {}
    hooks = []

    def make_hook(name):
        def hook(module, inp, out):
            t = out.detach()
            stats[name] = {
                "min": float(t.min()),
                "max": float(t.max()),
                "mean": float(t.mean()),
                "std": float(t.std()),
            }
        return hook

    for name, m in model.named_modules():
        if name in names_to_watch:
            hooks.append(m.register_forward_hook(make_hook(name)))

    return stats, hooks

# 사용 예
names = ["layer1.0.relu", "layer2.0.relu", "fc"]
stats, hooks = add_activation_hooks(model_fp32.eval(), names)
with torch.inference_mode():
    x, _ = next(iter(calib_loader))
    _ = model_fp32(x)
for h in hooks:
    h.remove()
print(stats)

관측 포인트:

  • 특정 레이어에서 max가 비정상적으로 크거나(극단 outlier)
  • std가 매우 작은데 범위만 큰 경우(희소 outlier)

이런 레이어는 MinMaxObserver에서 특히 취약합니다. HistogramObserver로 바꾸면 개선되는 경우가 많습니다.

7) “일부는 FP로 남기기”: 선택적 quantization 전략

모든 연산을 INT8로 내리는 게 항상 최선은 아닙니다. 특히 아래 연산은 모델에 따라 민감합니다.

  • attention score 계산 구간
  • softmax
  • layer norm
  • 작은 MLP에서의 마지막 projection

FX Graph Mode에서는 모듈 단위로 qconfig를 다르게 주어 “양자화 제외”를 만들 수 있습니다. 가장 단순한 방식은 특정 서브모듈에 None을 주는 것입니다.

import torch.ao.quantization as tq

qconfig_global = tq.get_default_qconfig("fbgemm")

qconfig_dict = {
    "": qconfig_global,
    # 예: 민감한 블록은 quantization 제외
    "module_name": [
        ("encoder.layernorm", None),
        ("head", None),
    ],
}

모듈 이름은 실제 모델 구조에 맞게 조정해야 합니다. “정확도 급락” 케이스는 보통 1~2개 구간만 FP로 돌려도 체감 성능이 크게 회복됩니다.

8) Calib 튜닝 체크리스트(실전 순서)

아래 순서대로 하면 시행착오를 줄일 수 있습니다.

  1. 전처리 100% 동일화: 학습/추론/Calib 파이프라인 비교
  2. Calib 데이터 대표성 확보: 서비스 로그 샘플이 있으면 최우선
  3. Calib 배치 수 증가: 200에서 시작해 1,000까지 올려보고 곡선 확인
  4. activation 옵저버를 histogram으로 변경: outlier 완화
  5. weight per-channel 확인: Conv/Linear는 per-channel이 유리한 경우 다수
  6. 민감 블록 FP 유지: layer norm/softmax 등부터 후보
  7. 백엔드 확인: fbgemm(x86) vs qnnpack(ARM) 결과가 다를 수 있음

9) 자주 묻는 함정: “PTQ는 원래 정확도 손실이 크지 않나요?”

모델/태스크에 따라 다릅니다.

  • ResNet 계열 분류 모델은 PTQ로도 손실이 작게 나오는 편
  • Transformer 계열, 특히 LayerNorm과 attention이 많은 구조는 PTQ만으로 손실이 커질 수 있음

후자의 경우 Calib 튜닝으로도 한계가 있으면, 다음 단계로는 QAT(Quantization Aware Training) 또는 GPTQ/RTN 같은 LLM 특화 기법을 고려합니다. 로컬 LLM 양자화 기법 비교는 아래 글이 참고됩니다.

10) 마무리: Calib은 “데이터”와 “옵저버”의 문제다

PyTorch PTQ에서 정확도 급락이 발생하면, 대부분은 다음 두 축으로 설명됩니다.

  • Calib 데이터가 실제 추론 분포를 대표하지 못했다
  • 옵저버/스케일 정책이 outlier에 취약하게 잡혔다

따라서 해결도 두 축에서 접근하는 게 가장 빠릅니다.

  • 서비스 입력과 동일한 전처리로 대표 샘플을 만들고
  • activation은 histogram 기반으로, weight는 per-channel로 시작하며
  • 민감한 블록은 과감히 FP로 남기는 방식으로 타협점을 찾으세요.

이 과정을 거치면 “INT8로 바꾸자마자 정확도 폭락” 같은 케이스는 상당수 안정적으로 복구됩니다.