- Published on
PyTorch QAT INT8 변환 후 정확도 하락 잡기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버 추론 비용을 줄이려고 PyTorch의 QAT(Quantization Aware Training)로 INT8 변환을 적용했는데, 변환 직후 정확도가 눈에 띄게 하락하는 경우가 자주 있습니다. QAT는 원래 PTQ(Post Training Quantization)보다 정확도 보존이 유리하지만, 설정 한두 군데만 어긋나도 학습은 잘 되는 듯 보이다가 변환 후 성능이 급격히 무너질 수 있습니다.
이 글에서는 torch.ao.quantization 기반 QAT 파이프라인을 전제로, 정확도 하락을 “원인별로” 빠르게 좁히는 진단 순서와 재현 가능한 수정 패턴을 정리합니다.
추가로, 운영에서 흔히 겪는 “캐시/설정 불일치로 인한 실패”와 유사하게 QAT도 옵저버/스케일/백엔드 설정 불일치가 핵심 원인인 경우가 많습니다. 그런 종류의 트러블슈팅 감각은 JWT 검증 실패 - JWKS kid 불일치·캐시 7가지 같은 글의 접근법과도 닮아 있습니다.
먼저 확인: QAT 정확도 하락의 80%는 “불일치”
QAT에서 정확도 하락이 크게 나는 패턴은 대체로 다음 중 하나입니다.
학습 시점과 변환 시점의 설정 불일치
qconfig를 바꿨는데 다시prepare_qat를 안 함fuse_modules전후 순서가 뒤섞임- 학습은
fbgemm가정인데 변환/실행은 다른 백엔드
옵저버(Observer) 캘리브레이션 품질 문제
- 학습 초기에 스케일이 요동치는데 freeze 타이밍이 너무 빠름
- 분포가 긴 꼬리인 텐서에
MinMaxObserver를 써서 스케일이 망가짐
연산자/레이어가 양자화 친화적이지 않음
SiLU,GELU,Softmax주변에서 스케일이 깨짐add/cat등 합류 지점에서 스케일 정렬이 안 됨
정밀도 혼합 정책이 잘못됨
- 일부 민감 레이어는
FP16또는FP32로 남겨야 하는데 전부INT8로 강제
- 일부 민감 레이어는
이제부터는 “어디서부터 봐야 가장 빨리 원인을 찾는지” 순서대로 설명합니다.
1) 베이스라인부터 고정: FP32, FakeQuant, INT8를 분리 측정
정확도 하락을 잡으려면 세 가지 정확도를 분리해서 봐야 합니다.
FP32원본 모델 정확도- QAT 준비 후
FakeQuant(가짜 양자화) 상태에서의 정확도 convert후 실제INT8모델 정확도
여기서 중요한 관찰:
FakeQuant에서도 정확도가 크게 떨어지면: 학습/옵저버/그래프 구성 문제FakeQuant는 괜찮은데INT8에서만 떨어지면: convert/백엔드/지원 연산자/폴딩 문제 가능성이 큼
아래는 최소한의 측정 뼈대입니다.
import torch
import torch.ao.quantization as tq
@torch.no_grad()
def evaluate(model, dataloader, device="cpu"):
model.eval()
correct = 0
total = 0
for x, y in dataloader:
x, y = x.to(device), y.to(device)
out = model(x)
pred = out.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.numel()
return correct / max(total, 1)
# 1) FP32
acc_fp32 = evaluate(model_fp32, val_loader)
# 2) FakeQuant(QAT 준비 후)
acc_fake = evaluate(model_qat_prepared, val_loader)
# 3) INT8(convert 후)
acc_int8 = evaluate(model_int8, val_loader)
print({"fp32": acc_fp32, "fake": acc_fake, "int8": acc_int8})
2) QAT의 정석 파이프라인: fuse → prepare_qat → train → convert
정확도 하락 이슈의 상당수는 순서가 틀리거나 중간 상태 모델을 재사용하면서 생깁니다.
권장 순서:
eval()상태에서fuse_modulestrain()상태로 전환qconfig설정prepare_qat- QAT 파인튜닝
eval()전환convert
import copy
import torch
import torch.ao.quantization as tq
# 백엔드 선택: x86 서버면 보통 fbgemm
torch.backends.quantized.engine = "fbgemm"
model = copy.deepcopy(model_fp32)
model.eval()
# 예: Conv+BN+ReLU fuse (모델 구조에 맞게 수정)
# tq.fuse_modules는 모듈 경로 문자열 리스트를 받음
model = tq.fuse_modules(model, [["conv", "bn", "relu"]], inplace=False)
# QAT 준비
model.train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
model_prepared = tq.prepare_qat(model, inplace=False)
# QAT 파인튜닝(짧게라도 필수)
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=1e-4, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(3):
model_prepared.train()
for x, y in train_loader:
optimizer.zero_grad()
out = model_prepared(x)
loss = criterion(out, y)
loss.backward()
optimizer.step()
# 변환
model_prepared.eval()
model_int8 = tq.convert(model_prepared, inplace=False)
체크 포인트:
fuse_modules는 반드시eval()에서 수행하는 것이 일반적입니다.prepare_qat이후에qconfig를 바꾸면 의미가 없습니다. 바꿨다면 다시 준비해야 합니다.- 학습이 너무 짧으면 옵저버 통계가 안정화되지 않아
INT8스케일이 불안정할 수 있습니다.
3) 옵저버가 문제인지 확인: activation 분포에 맞는 observer 선택
정확도 하락이 큰데 학습 로그는 멀쩡하다면, 종종 activation 스케일이 망가진 것입니다.
대표적으로:
- 분포가 긴 꼬리(outlier)를 가지면
MinMaxObserver는 스케일이 과도하게 커져 유효 비트가 줄어듭니다. - 이때는
HistogramObserver또는MovingAverageMinMaxObserver가 더 안정적인 경우가 많습니다.
커스텀 qconfig 예시:
import torch.ao.quantization as tq
activation_observer = tq.HistogramObserver.with_args(
reduce_range=False
)
weight_observer = tq.default_per_channel_weight_observer
qconfig = tq.QConfig(
activation=activation_observer,
weight=weight_observer
)
model.qconfig = qconfig
model_prepared = tq.prepare_qat(model, inplace=False)
추가 팁:
- CNN 계열은 weight per-channel이 정확도에 크게 유리한 경우가 많습니다.
- 반대로 일부 모바일/특정 백엔드에서는 per-tensor만 지원하거나 성능/정확도 트레이드오프가 달라집니다.
4) 옵저버 freeze 타이밍: 너무 빨리 얼리면 망한다
QAT에서는 보통 학습 중간에 다음을 수행합니다.
- 옵저버 업데이트 중지(
disable_observer) - fake quant 고정(
freeze_bn_stats또는 BN 관련 처리)
너무 이르게 freeze하면, 아직 분포가 안정화되지 않아 잘못된 스케일로 굳어버립니다.
import torch.ao.quantization as tq
def set_qat_freeze(model):
# 옵저버 비활성화
model.apply(tq.disable_observer)
# FakeQuant 고정(스케일/제로포인트 업데이트 중지)
model.apply(tq.disable_fake_quant)
# 예: 전체 epoch 중 후반부에만 freeze
for epoch in range(10):
train_one_epoch(model_prepared)
if epoch == 7:
model_prepared.apply(tq.disable_observer)
if epoch == 8:
model_prepared.apply(tq.disable_fake_quant)
주의:
- 프로젝트에 따라 freeze 정책이 다릅니다. 핵심은 스케일 통계가 충분히 수렴한 뒤에 고정하는 것입니다.
5) 민감 레이어는 INT8로 보내지 말고 “선별적으로” 남겨라
정확도 하락이 특정 블록에서만 발생한다면, 전부를 INT8로 만들기보다 일부 연산은 FP32로 유지하는 전략이 효과적입니다.
실무에서 자주 민감한 구간:
- 입력단/출력단(첫 Conv, 마지막 FC)
- Attention, Softmax 주변
- 작은 채널 수에서의 depthwise 계열
선별적으로 qconfig = None을 주는 패턴:
import torch.ao.quantization as tq
# 예: 마지막 분류기 레이어는 FP32 유지
model.classifier.qconfig = None
# 또는 특정 서브모듈 전체를 제외
model.head.qconfig = None
model_prepared = tq.prepare_qat(model, inplace=False)
포인트:
- “정확도에 민감한 곳만 FP32로 남기기”는 속도 이득을 크게 해치지 않으면서 품질을 회복하는 경우가 많습니다.
6) 합류 지점(add/cat)에서 스케일이 깨질 때: QuantStub/DeQuantStub 위치 점검
ResNet류처럼 add가 많은 네트워크는 합류 지점에서 두 텐서의 스케일/제로포인트 정렬이 관건입니다.
일반적으로는 QuantStub와 DeQuantStub을 모델 입출력에 배치하고, 내부는 FX 그래프 모드가 더 안정적인 편입니다. 다만 eager mode를 쓴다면 스텁 위치가 어긋나 float와 quantized가 섞여 예상치 못한 디퀀트가 발생할 수 있습니다.
간단한 스텁 예시:
import torch
import torch.nn as nn
import torch.ao.quantization as tq
class QuantWrapper(nn.Module):
def __init__(self, core):
super().__init__()
self.quant = tq.QuantStub()
self.core = core
self.dequant = tq.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.core(x)
x = self.dequant(x)
return x
model = QuantWrapper(model_fp32)
증상 기반 힌트:
FakeQuant는 괜찮은데INT8만 급락한다면, 합류 지점에서의 quantize/dequantize 삽입이 달라졌을 가능성이 있습니다.
7) 백엔드/커널 문제: fbgemm vs qnnpack 불일치
QAT는 “학습 때 가정한 양자화 스킴”과 “실제 실행 커널”이 맞아야 합니다.
- 서버 x86: 보통
fbgemm - 모바일 ARM: 보통
qnnpack
학습/변환/서빙 환경이 다르면 정확도뿐 아니라 출력 자체가 달라 보일 수 있습니다.
import torch
torch.backends.quantized.engine = "fbgemm" # 또는 "qnnpack"
print(torch.backends.quantized.engine)
운영에서 환경 차이로 문제가 생기는 패턴은 다른 분야에서도 흔합니다. 예를 들어 EKS에서 노드/보안 설정 차이로 특정 기능만 실패하는 경우처럼요. 비슷한 트러블슈팅 접근은 EKS에서 kubectl exec·logs가 안 될 때 진단법 같은 글의 “환경-권한-경로를 나눠서 확인”하는 방식이 참고가 됩니다.
8) 데이터 전처리 불일치: INT8에서만 더 크게 터진다
QAT 자체보다 더 자주 놓치는 것이 전처리/정규화 불일치입니다.
- FP32 학습/평가에서는 약간의 전처리 차이가 티가 덜 나는데
- INT8에서는 activation 범위가 제한되면서 작은 불일치가 더 큰 손실로 증폭됩니다.
체크리스트:
- 정규화(mean/std)가 학습과 동일한가
- 입력 dtype이
float32로 들어가고 있는가 - 채널 순서(NCHW) 및 스케일(0-1, 0-255)이 동일한가
간단한 가드 코드:
def assert_input_ok(x):
assert x.dtype == torch.float32
assert x.ndim == 4
# 값 범위 점검(예: 0-1 가정)
assert x.min().item() >= -0.1 and x.max().item() <= 1.1
9) 디버깅 실전: 어디 레이어에서 오차가 폭발하는지 찾기
정확도 하락을 “감”으로 때려 맞추지 말고, 레이어별로 FP32 출력과 FakeQuant 또는 INT8 출력 차이를 측정해 병목을 찾는 게 빠릅니다.
훅 기반 간단 비교:
import torch
def collect_activations(model, layer_names):
acts = {}
hooks = []
name_to_module = dict(model.named_modules())
def make_hook(name):
def hook(m, inp, out):
# out이 tuple일 수 있어 방어
t = out[0] if isinstance(out, (tuple, list)) else out
acts[name] = t.detach().float().cpu()
return hook
for n in layer_names:
hooks.append(name_to_module[n].register_forward_hook(make_hook(n)))
return acts, hooks
@torch.no_grad()
def compare_one_batch(model_a, model_b, x, layer_names):
acts_a, hooks_a = collect_activations(model_a, layer_names)
acts_b, hooks_b = collect_activations(model_b, layer_names)
_ = model_a(x)
_ = model_b(x)
for h in hooks_a + hooks_b:
h.remove()
diffs = {}
for n in layer_names:
a = acts_a[n]
b = acts_b[n]
diffs[n] = (a - b).abs().mean().item()
return diffs
# 사용 예: 중요한 블록 몇 개만 찍어도 방향이 보임
layer_names = [
"core.layer1.0.conv1",
"core.layer2.0.conv1",
"core.layer4.1.conv2",
]
diffs = compare_one_batch(model_fp32.eval(), model_prepared.eval(), x_batch, layer_names)
print(diffs)
이렇게 “오차가 급증하는 지점”을 찾은 뒤, 그 블록만 qconfig=None으로 제외하거나 observer를 바꾸는 식으로 빠르게 수렴시킬 수 있습니다.
10) 흔한 처방전(우선순위 순)
정확도 하락을 만났을 때, 아래 순서로 적용하면 시행착오가 줄어듭니다.
- 측정 분리:
FP32vsFakeQuantvsINT8 - 파이프라인 순서 고정:
fuse후prepare_qat, 그리고 변환 전eval - 백엔드 통일:
torch.backends.quantized.engine확인 - 옵저버 교체:
MinMax에서Histogram또는MovingAverage로 - freeze 타이밍 늦추기: 후반부에 observer/fakequant 고정
- 민감 레이어 제외: 마지막 FC, 첫 Conv, 특정 블록을 FP32 유지
- 전처리 재검증: 정규화/스케일/채널 순서
- 레이어별 오차 추적: 훅으로 폭발 지점 찾기
마무리
PyTorch QAT에서 INT8 변환 후 정확도 하락은 “QAT가 원래 어렵다”기보다, 대부분 설정과 상태의 불일치, 그리고 옵저버 통계의 부정확에서 시작합니다.
핵심은 한 번에 여러 가설을 건드리지 말고, FakeQuant와 INT8의 차이를 분리해 관측한 뒤, 옵저버/백엔드/민감 레이어를 작게, 반복적으로 조정하는 것입니다.
다음 단계로는 모델이 CNN인지 Transformer인지, 그리고 목표 런타임이 torchscript인지 onnxruntime인지에 따라 최적 처방이 달라집니다. 원하시면 모델 종류와 현재 qconfig, 정확도 하락 폭(top1, mAP 등), 그리고 FakeQuant 단계 정확도를 기준으로 더 구체적인 튜닝 플랜까지 정리해드릴 수 있습니다.