- Published on
PyTorch QAT로 INT8 양자화 정확도 하락 막기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버 추론 비용을 줄이기 위해 INT8 양자화를 적용했는데 정확도가 크게 떨어진 경험이 있다면, 대부분은 캘리브레이션(PTQ)만으로는 모델이 감당 못 하는 양자화 노이즈가 원인입니다. 이때 QAT(Quantization Aware Training)는 학습 중에 양자화 효과를 모사(FakeQuant)해 가중치와 활성값이 INT8 제약을 견디도록 적응시키는 방식이라, 정확도 하락을 크게 줄일 수 있습니다.
이 글은 PyTorch의 eager mode QAT 흐름을 기준으로, 정확도 하락을 막는 실전 체크리스트와 코드 패턴을 정리합니다. (프로덕션 배포 파이프라인에서 자주 겪는 “원인은 여러 개인데 증상은 비슷한” 문제를 줄이는 데 초점을 둡니다.)
운영 환경에서 성능/비용 최적화는 모델만의 문제가 아니라 배포 파이프라인 이슈와도 맞물립니다. 예를 들어 CI에서 권한/토큰 문제로 배포가 꼬이면 실험 결과 재현이 어려워집니다. 필요하면 GitHub Actions OIDC로 AWS 자격증명 0초 발급 같은 글도 함께 참고하세요.
QAT가 정확도를 지키는 원리(PTQ와 차이)
- PTQ(Post-Training Quantization): 학습 완료 FP32 모델을 INT8로 바꾼 뒤, 작은 캘리브레이션 데이터로 스케일과 제로포인트를 추정합니다. 데이터 분포가 바뀌거나(outlier가 많거나) 레이어별 민감도가 크면 정확도 하락이 큽니다.
- QAT: 학습 중 forward에 FakeQuant를 삽입해, 역전파가 양자화 오차를 “본 것처럼” 파라미터를 업데이트합니다. 결과적으로
- 가중치 분포가 양자화 친화적으로 정리되고
- 활성값의 범위가 관측기(observer)가 잡기 쉬운 형태로 안정화됩니다.
정확도 하락을 막는 핵심은 결국 아래 3가지입니다.
- 올바른 fuse 패턴(Conv-BN-ReLU 등)
- 올바른 qconfig(backend, per-channel, observer 설정)
- 올바른 학습 스케줄(옵저버/페이크퀀트 enable 타이밍, LR, epoch)
사전 점검: QAT 전 정확도 하락을 키우는 흔한 원인
1) Conv-BN-ReLU를 fuse하지 않음
BN이 남아 있으면 양자화 과정에서 스케일이 꼬이기 쉽고, 연산 그래프도 비효율적입니다. QAT는 보통 fuse를 전제로 합니다.
2) 활성값 outlier가 많은데 MinMax observer를 그대로 사용
활성값에 outlier가 있으면 MinMax가 범위를 과하게 넓혀서, 대부분 값의 유효 비트가 줄어듭니다. 이때는 Histogram 기반 observer(또는 percentile 계열 전략)가 유리한 경우가 많습니다.
3) per-tensor weight quant를 사용
가중치는 per-channel이 정확도에 유리한 경우가 많습니다(특히 Conv). per-tensor는 채널별 분포 차이를 못 따라가서 손실이 커질 수 있습니다.
4) 학습 스케줄이 “그냥 몇 epoch 더” 수준
QAT는 옵저버 통계가 안정화되는 구간과, FakeQuant가 본격적으로 들어가서 모델이 적응하는 구간이 다릅니다. 이 타이밍을 제어하지 않으면 정확도 회복이 안 됩니다.
PyTorch eager mode QAT 기본 파이프라인
아래 코드는 전형적인 흐름입니다.
- FP32 모델 준비 및 fuse
- qconfig 세팅
prepare_qat- QAT 파인튜닝
convert로 실제 INT8 모델 생성
예시 모델(간단한 Conv 블록)
import torch
import torch.nn as nn
class SmallCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = self.features(x)
x = self.pool(x).flatten(1)
return self.fc(x)
1) Fuse: Conv-BN-ReLU 묶기
torch.ao.quantization.fuse_modules는 모듈 이름 경로를 문자열로 받습니다. nn.Sequential이면 인덱스 문자열을 사용합니다.
import torch.ao.quantization as aq
def fuse_model(model: nn.Module) -> nn.Module:
# features[0]=Conv, [1]=BN, [2]=ReLU / features[3]=Conv, [4]=BN, [5]=ReLU
aq.fuse_modules(model.features, [['0', '1', '2'], ['3', '4', '5']], inplace=True)
return model
fuse가 잘못되면 정확도 이전에 성능이 안 나오거나, convert 단계에서 기대한 INT8 커널로 안 떨어지는 일이 잦습니다.
2) Backend 선택과 qconfig
CPU INT8에서 보통 backend는 fbgemm(x86 서버) 또는 qnnpack(ARM 모바일)을 씁니다.
- x86 서버:
fbgemm권장 - ARM:
qnnpack권장
import torch
import torch.ao.quantization as aq
torch.backends.quantized.engine = 'fbgemm'
model = SmallCNN().eval()
model = fuse_model(model)
# 기본 qconfig (대부분의 시작점)
model.qconfig = aq.get_default_qat_qconfig('fbgemm')
정확도 하락이 크면 qconfig를 더 보수적으로 조정합니다. 예를 들어 활성값 observer를 histogram 기반으로 바꾸거나, weight per-channel을 강제하는 식입니다.
PyTorch 버전/백엔드에 따라 기본 qconfig 내부가 조금씩 다릅니다. “기본값이 항상 최선”이 아니라, 데이터 분포에 맞게 observer를 바꾸는 게 정확도 방어에 효과적입니다.
3) prepare_qat
model.train()
model_prepared = aq.prepare_qat(model, inplace=False)
이 시점부터 forward에는 FakeQuant 모듈이 삽입되고, observer가 통계를 수집합니다.
정확도 하락을 막는 학습 스케줄(핵심)
QAT에서 가장 흔한 실패는 “학습을 했는데도 정확도가 안 돌아오는” 케이스입니다. 그 이유는 대개 observer 통계가 불안정한 상태에서 FakeQuant를 너무 강하게 적용했기 때문입니다.
권장 스케줄(실전에서 많이 쓰는 형태)
- 워밍업: observer는 켜고 FakeQuant는 끈 상태로 몇 백 step 또는 1 epoch
- 적응: FakeQuant를 켜고 학습
- 후반: observer 업데이트를 멈추고(통계 고정) FakeQuant만 유지한 채 미세 조정
PyTorch QAT 모듈은 enable_observer, disable_observer, enable_fake_quant, disable_fake_quant를 제공합니다.
def set_qat_state(m: nn.Module, observer: bool, fake_quant: bool):
if observer:
m.apply(aq.enable_observer)
else:
m.apply(aq.disable_observer)
if fake_quant:
m.apply(aq.enable_fake_quant)
else:
m.apply(aq.disable_fake_quant)
예시 학습 루프(개념 코드):
import torch.nn.functional as F
def train_qat(model, train_loader, optimizer, device, epochs=10):
model.to(device)
for epoch in range(epochs):
model.train()
# 스케줄 예시
if epoch == 0:
# 통계 수집: observer on, fake_quant off
set_qat_state(model, observer=True, fake_quant=False)
elif epoch == 1:
# 본격 QAT: observer on, fake_quant on
set_qat_state(model, observer=True, fake_quant=True)
elif epoch >= 3:
# 후반 안정화: observer off, fake_quant on
set_qat_state(model, observer=False, fake_quant=True)
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad(set_to_none=True)
logits = model(x)
loss = F.cross_entropy(logits, y)
loss.backward()
optimizer.step()
이 스케줄이 좋은 이유는 다음과 같습니다.
- 초반에 observer가 데이터 분포를 충분히 본 뒤
- FakeQuant로 인한 노이즈가 들어오고
- 후반에는 observer 범위가 흔들리지 않게 고정해 수렴을 돕습니다.
러닝레이트(LR) 팁
- FP32에서 잘 학습된 모델을 QAT로 파인튜닝할 때는 보통 LR을 10분의 1에서 100분의 1로 낮추는 게 안정적입니다.
- 정확도 하락이 큰 경우, epoch을 늘리기보다 **LR 스케줄(코사인 디케이, 스텝 디케이)**을 조정하는 편이 효율적일 때가 많습니다.
Observer와 FakeQuant 설정으로 정확도 방어하기
1) per-channel weight quant를 우선 고려
특히 Conv 레이어에서 per-channel은 정확도에 큰 차이를 만들 수 있습니다.
기본 QAT qconfig가 이미 per-channel을 쓰는 경우가 많지만, 모델/버전에 따라 다를 수 있어 확인이 필요합니다.
확인 방법(간단):
for name, m in model_prepared.named_modules():
if 'weight_fake_quant' in name:
print(name, m)
break
2) 활성값 outlier가 많으면 histogram observer 고려
데이터 분포에 outlier가 많으면 MinMax는 취약합니다. 이때 histogram 기반이 더 나은 경우가 많습니다.
PyTorch에서 observer를 직접 구성해 qconfig로 넣을 수 있습니다.
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
act_fake_quant = FakeQuantize.with_args(
observer=HistogramObserver,
quant_min=0,
quant_max=255,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
)
wt_fake_quant = FakeQuantize.with_args(
observer=PerChannelMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
)
model.qconfig = QConfig(activation=act_fake_quant, weight=wt_fake_quant)
주의할 점:
- backend에 따라 지원/최적화되는 qscheme이 다를 수 있습니다.
- activation을 per-channel로 하고 싶어도, 많은 backend는 활성값 per-tensor를 전제로 최적화됩니다.
3) 첫 레이어/마지막 레이어는 FP32 유지가 유리할 때가 있음
정확도 민감도가 높은 구간(입력에 가까운 레이어, logits 직전 레이어)은 INT8로 바꿀 때 손실이 커질 수 있습니다.
완전한 “전 구간 INT8”이 목표가 아니라면, 일부 레이어를 FP32로 남기는 혼합 정밀도 전략이 실전에서 자주 통합니다.
eager mode에서 특정 모듈만 qconfig를 None으로 두면 해당 모듈은 양자화 대상에서 제외됩니다.
model = SmallCNN()
model = fuse_model(model)
# 마지막 fc는 FP32로 유지(예시)
model.fc.qconfig = None
# 나머지는 QAT
model.qconfig = aq.get_default_qat_qconfig('fbgemm')
model_prepared = aq.prepare_qat(model)
Convert 후 검증: “진짜 INT8로 떨어졌는지” 확인하기
학습이 잘 됐더라도 convert 결과가 기대와 다르면 정확도/성능이 모두 흔들립니다.
model_prepared.eval()
model_int8 = aq.convert(model_prepared, inplace=False)
print(model_int8)
여기서 확인할 포인트:
QuantizedConv2d,QuantizedLinear같은 모듈로 바뀌었는지- fuse된 블록이 적절히 양자화 연산으로 치환됐는지
또한 벤치마크 시에는 반드시 같은 전처리/후처리를 쓰고, 입력 dtype과 스케일이 일관적인지 확인해야 합니다.
정확도 하락 디버깅 체크리스트(우선순위)
1) 데이터 파이프라인이 FP32 학습 때와 동일한가
전처리(정규화, 리사이즈, 색공간)가 조금만 달라도 observer 통계가 달라져 성능이 흔들립니다. 특히 캘리브레이션/검증 데이터 분포가 다르면 QAT 효과가 반감됩니다.
데이터/로그가 깨져서 원인 파악이 어려운 경우도 많습니다. 파이프라인이 Parquet나 Arrow 기반이라면, UTF-8 문제로 일부 샘플이 누락되거나 치환되는 경우도 있으니 필요하면 PyArrow Invalid - UTF-8 디코딩 오류 해결 가이드도 함께 점검하세요.
2) fuse가 올바른가
Conv-BN-ReLU가 fuse되지 않으면 QAT가 제대로 수렴해도 convert 결과가 비정상적일 수 있습니다.
3) backend(engine)와 qconfig가 일치하는가
torch.backends.quantized.engine이 fbgemm인데 qconfig를 qnnpack 기준으로 잡는 식의 불일치는 피하세요.
4) observer 안정화 후 fake quant를 켰는가
앞서 제시한 스케줄대로 observer와 fake quant를 단계적으로 제어해 보세요.
5) per-channel weight quant를 적용했는가
Conv에서 per-channel 여부는 정확도에 큰 차이가 날 수 있습니다.
실전 팁: “정확도 하락 1퍼센트”를 줄이는 자잘하지만 큰 차이
- 배치 크기: QAT는 통계 수집과 노이즈 적응이 핵심이라, 너무 작은 배치는 observer 통계를 불안정하게 만들 수 있습니다.
- EMA나 weight decay: 과한 weight decay는 가중치 분포를 좁혀 양자화 친화적으로 보일 수 있지만, 표현력이 떨어져 정확도 손실이 날 수 있습니다. 기존 FP32 설정을 그대로 가져오지 말고 재탐색하세요.
- 검증은 반드시 convert된 INT8 모델로: prepared(QAT) 모델의 fake quant 정확도와 convert된 실제 INT8 정확도가 다를 수 있습니다.
- 레이어별 민감도 분석: 특정 블록만 INT8에서 급락한다면 그 블록만 FP32로 남기거나, 그 블록의 observer를 조정하는 식으로 빠르게 수렴시킬 수 있습니다.
마무리: QAT 성공의 최소 조건
INT8 QAT로 정확도 하락을 막으려면, 단순히 prepare_qat 후 epoch 몇 번 더 돌리는 것보다 아래 3가지를 먼저 고정하는 게 효과가 큽니다.
- Fuse를 제대로 한다(Conv-BN-ReLU)
- qconfig를 데이터 분포에 맞춘다(특히 activation observer, weight per-channel)
- observer와 fake quant 스케줄을 나눈다(초반 통계 수집, 중반 적응, 후반 고정)
이 3가지만 지켜도 “INT8로 바꾸면 정확도가 크게 떨어진다”는 문제의 상당 부분은 해결됩니다. 이후에는 모델/데이터 특성에 맞춰 일부 레이어 FP32 유지, observer 교체, 학습 스케줄 튜닝으로 마지막 정확도를 끌어올리면 됩니다.