- Published on
PyTorch 2.0 PT2E+INT8 양자화 정확도 지키기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서론
PyTorch 2.0부터 양자화 워크플로는 torch.export 중심으로 재편되었고, 그 흐름 위에서 동작하는 PT2E(Prepare/Convert Export) 방식이 점점 표준이 되고 있습니다. 기존 FX Graph Mode 양자화에 익숙한 팀이라면, “INT8로 바꾸면 속도는 빨라지는데 정확도가 갑자기 떨어진다”는 문제를 PT2E에서도 동일하게 겪습니다.
정확도 하락은 대개 “INT8 자체가 나빠서”가 아니라, (1) 캘리브레이션 데이터/절차, (2) 관측자(observer) 설정, (3) per-tensor vs per-channel, (4) 연산 그래프에서의 fusion/패턴 미스매치, (5) 특정 레이어에 대한 무리한 양자화 같은 구성 문제에서 발생합니다.
이 글은 PyTorch 2.0 환경에서 PT2E 기반 INT8 양자화를 적용할 때, 정확도를 최대한 유지하는 실전 방법을 체크리스트처럼 정리합니다. 모델은 CNN/Transformer 모두에 적용 가능한 원칙 위주로 설명하고, 코드 예제는 PT2E 스타일로 제공합니다.
PT2E 양자화 한 장 요약
PT2E는 크게 다음 단계로 생각하면 됩니다.
torch.export로 그래프를 “고정”한다(동적 제어 흐름/파이썬 로직을 최소화)prepare_pt2e로 관측자를 삽입한다- 대표 입력/캘리브레이션 데이터로 모델을 실행해 통계를 모은다
convert_pt2e로 실제 INT8 연산(또는 INT8 가중치/활성)을 사용하는 모델로 변환한다
정확도는 2~3 단계에서 대부분 결정됩니다. 관측자가 무엇을 어떻게 측정했는지, 캘리브레이션이 실제 서빙 분포를 얼마나 닮았는지가 핵심입니다.
정확도가 떨어지는 7가지 대표 원인
1) 캘리브레이션 데이터가 실제 입력 분포를 못 따라감
가장 흔한 원인입니다. 예를 들어 이미지 모델에서 캘리브레이션을 32장만 대충 돌리거나, 전처리(정규화/리사이즈/패딩)가 서빙과 다르면 activation range가 왜곡됩니다. 그러면 scale/zero-point가 엉뚱해지고, INT8에서 정보 손실이 커집니다.
권장:
- 최소 수백~수천 샘플(모델/도메인에 따라 다름)
- 전처리 파이프라인을 서빙과 1:1로 맞추기
- 클래스 편향이 심한 데이터만 쓰지 않기
2) per-tensor 양자화로 채널별 분포 차이를 무시함
Conv/Linear 가중치는 채널별 분포가 크게 다릅니다. per-tensor로 묶어버리면 outlier 채널 때문에 scale이 커져 대부분 채널의 유효 비트가 줄어듭니다.
권장:
- 가중치는 가능하면 per-channel(대개
qscheme이per_channel_symmetric)을 우선 고려
3) activation 관측자 선택이 모델 특성과 안 맞음
MinMax 기반 관측자는 outlier에 취약합니다. Transformer 계열에서 활성 값의 꼬리가 길면, MinMax는 scale을 과하게 키우고 양자화 노이즈가 커집니다.
권장:
- 분포가 두꺼운 모델은 histogram/percentile 계열 관측자(가능한 경우) 또는 clipping 전략 고려
- 레이어별로 관측자를 다르게 가져가는 것도 실전에서 효과가 큼
4) 연산 fusion이 기대대로 안 되어 스케일 전파가 불리해짐
Conv+BN+ReLU 같은 패턴이 잘 fuse되면 양자화 포인트가 줄고 오차가 줄어듭니다. 반대로 fuse가 안 되면 중간 activation이 불필요하게 양자화/역양자화되거나, 관측 포인트가 늘어 오차가 누적될 수 있습니다.
권장:
- export 후 그래프를 확인해 원하는 패턴이 남아 있는지 점검
- 모델을 “양자화 친화적” 구조로 미리 정리(불필요한 view/transpose 남발, 커스텀 op 난립 등 최소화)
5) 레이어별 민감도(quant sensitivity)를 무시함
모든 레이어를 동일 정책으로 INT8로 밀어붙이면 대개 마지막 분류기, 첫 번째 stem, attention의 특정 projection, layernorm 주변에서 정확도가 크게 흔들립니다.
권장:
- “부분 양자화” 전략: 민감한 블록은 FP16/FP32 유지
- 레이어별 에러를 측정해 예외 리스트를 만들기
6) 동적 범위가 큰 연산(Softmax, LayerNorm 등)을 무리하게 양자화
일부 연산은 INT8로 표현할 때 손실이 커지거나, 백엔드가 사실상 FP로 처리하면서 오히려 성능/정확도 모두 손해가 날 수 있습니다.
권장:
- Softmax/LayerNorm은 FP로 유지하는 경우가 많음(백엔드 지원과 모델에 따라 다름)
7) 평가 파이프라인 자체가 흔들림
정확도 비교를 할 때, FP32 기준 모델이 eval() 이 아니거나, dropout이 켜져 있거나, seed가 고정되지 않아 분산이 커지면 “양자화 때문에 떨어졌다”라고 오판하기 쉽습니다.
권장:
- FP32/INT8 모두 동일한 전처리, 동일한
model.eval() - 동일한 샘플셋으로 비교
PT2E INT8 양자화 기본 코드(골격)
아래 예시는 PT2E 스타일의 전체 흐름을 보여주는 뼈대입니다. 실제 API 이름/옵션은 PyTorch 마이너 버전에 따라 조금씩 달라질 수 있으니, 설치 버전 문서를 함께 확인하세요.
import torch
# PT2E 관련 모듈은 버전에 따라 경로가 다를 수 있습니다.
# 아래는 개념적 예시입니다.
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
def calibrate(model, dataloader, device="cpu", num_batches=50):
model.eval()
with torch.inference_mode():
for i, batch in enumerate(dataloader):
if i >= num_batches:
break
x = batch[0].to(device)
_ = model(x)
def build_qconfig_mapping():
# 가중치: per-channel, activation: per-tensor (기본 출발점)
act_obs = MinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
wt_obs = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
qconfig = torch.ao.quantization.QConfig(activation=act_obs, weight=wt_obs)
return QConfigMapping().set_global(qconfig)
def pt2e_int8_quantize(model_fp32, example_inputs, calib_loader):
model_fp32.eval()
# 1) export
exported = torch.export.export(model_fp32, example_inputs)
# 2) prepare (observer 삽입)
qconfig_mapping = build_qconfig_mapping()
prepared = prepare_pt2e(exported, qconfig_mapping=qconfig_mapping)
# 3) calibration
calibrate(prepared, calib_loader, device="cpu", num_batches=100)
# 4) convert
quantized = convert_pt2e(prepared)
return quantized
이 골격만으로도 “일단 돌아가는 INT8”은 만들 수 있습니다. 하지만 정확도 방어는 여기서부터 시작입니다.
정확도 방어 1: 캘리브레이션을 ‘서빙과 동일’하게 만들기
캘리브레이션은 관측자가 activation의 범위를 추정하는 과정입니다. 따라서 다음을 반드시 맞추세요.
- 전처리: 정규화(mean/std), 토큰화, padding, truncation, 이미지 리사이즈/센터크롭 정책
- 입력 길이 분포: LLM/Transformer는 시퀀스 길이가 range에 큰 영향
- 배치 크기: 보통은 큰 영향이 없지만, 일부 모델은 batchnorm/통계 경로 때문에 영향을 받을 수 있음
또한 “대표 입력”의 다양성이 중요합니다. 클래스가 100개면 최소한 클래스 분포가 어느 정도 섞인 샘플을 쓰는 편이 안정적입니다.
정확도 방어 2: per-channel 가중치는 거의 필수로 생각하기
Conv/Linear 가중치를 per-channel로 바꾸는 것만으로도 정확도가 눈에 띄게 회복되는 경우가 많습니다. 특히 depthwise conv나 채널별 스케일 편차가 큰 모델에서 효과가 큽니다.
위 코드의 PerChannelMinMaxObserver 설정이 바로 그 출발점입니다.
추가 팁:
- 가중치 관측자는 대개 symmetric이 유리합니다(제로포인트가 0 근처라 오차가 줄어듦)
- activation은 affine이 일반적(0을 정확히 포함하지 않는 분포가 많음)
정확도 방어 3: outlier 대응(클리핑/히스토그램 기반)
MinMax는 outlier 하나에 민감합니다. Transformer류에서 특정 토큰/문장 패턴이 outlier를 만들면, 전체 레이어의 scale이 커져 양자화 노이즈가 증가합니다.
가능한 선택지:
- histogram 기반 관측자
- percentile 기반 클리핑(예: 상위 99.9퍼센타일까지만 범위로 사용)
- 레이어별 activation을 분석해 outlier가 심한 레이어만 별도 정책 적용
PyTorch에서 어떤 관측자를 쓸 수 있는지는 버전과 백엔드에 따라 다르므로, “모든 레이어에 동일 observer”에서 벗어나 레이어별로 바꾸는 전략을 준비해 두는 것이 좋습니다.
정확도 방어 4: 레이어별 예외 처리(부분 양자화)
실무에서 가장 효과적인 방법 중 하나는 “정확도 민감 레이어는 FP로 남겨두는 것”입니다. 속도는 약간 손해 보지만, 정확도는 큰 폭으로 회복되곤 합니다.
대표적으로 예외 후보:
- 입력 stem(첫 conv/patch embedding)
- 출력 head(마지막 linear)
- attention의 특정 projection
- layernorm 주변
PT2E에서는 보통 QConfigMapping을 이용해 모듈/연산 패턴별로 qconfig를 다르게 주거나, 특정 모듈을 아예 양자화 대상에서 제외하는 형태로 접근합니다.
개념 예시:
from torch.ao.quantization import QConfigMapping
qconfig_mapping = (
QConfigMapping()
.set_global(qconfig_int8)
# 예: 마지막 분류기는 FP 유지
.set_module_name("classifier", None)
)
여기서 핵심은 “예외를 감으로 정하지 말고” 측정으로 정하는 것입니다.
정확도 방어 5: 민감도 측정(레이어별 에러)로 예외 리스트 만들기
정확도 하락이 크면, 다음과 같이 “레이어별로 FP 출력과 양자화 출력의 차이”를 로그로 남겨 민감 레이어를 찾는 방식이 효과적입니다.
간단한 접근:
- 동일 입력에 대해 FP32 모델과 prepared 모델(관측자 삽입 상태)을 각각 실행
- 특정 모듈의 출력 텐서를 hook으로 수집
MSE,cosine distance,max abs error등을 계산
예시 코드(아이디어용):
import torch
import torch.nn.functional as F
def collect_activations(model, x, module_names):
acts = {}
hooks = []
name_to_module = dict(model.named_modules())
def make_hook(name):
def hook(_m, _inp, out):
if torch.is_tensor(out):
acts[name] = out.detach().float().cpu()
return hook
for n in module_names:
m = name_to_module[n]
hooks.append(m.register_forward_hook(make_hook(n)))
model.eval()
with torch.inference_mode():
_ = model(x)
for h in hooks:
h.remove()
return acts
def compare(fp_acts, q_acts):
scores = {}
for k in fp_acts.keys():
a = fp_acts[k]
b = q_acts[k]
mse = F.mse_loss(a, b).item()
cos = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
scores[k] = {"mse": mse, "cos": cos}
return scores
이렇게 점수를 뽑아보면 “어느 레이어에서 오차가 폭발하는지”가 보이고, 그 레이어를 FP로 남기거나 관측자를 바꾸는 식으로 빠르게 수렴할 수 있습니다.
정확도 방어 6: 백엔드와 dtype 조합을 현실적으로 선택하기
INT8 양자화는 “어떤 커널이 실제로 INT8로 돌고 있는지”가 중요합니다. CPU에서는 보통 FBGEMM(서버) 또는 QNNPACK(모바일)이 관여하고, 지원 연산/패턴이 다릅니다. 지원이 약한 연산을 억지로 양자화하면 내부적으로 dequant-quant가 늘어나 정확도/성능이 같이 나빠질 수 있습니다.
권장:
- 목표 디바이스에서 프로파일링으로 실제 커널이 무엇인지 확인
- 지원이 약한 블록은 FP16/FP32 유지
정확도 방어 7: 평가/서빙 환경 불일치 제거
양자화 모델은 운영 파이프라인에서 “환경 차이”에 더 민감해지는 경향이 있습니다. 예를 들어 컨테이너 이미지에서 CPU feature flag(AVX2/AVX512), 스레드 설정, 라이브러리 버전이 달라져 결과가 미세하게 흔들릴 수 있습니다.
운영 이슈를 줄이는 관점에서, 배포 자동화 파이프라인에서 조건 분기나 매트릭스 설정이 꼬이면 재현이 어려워집니다. CI 설정이 복잡하다면 다음 글도 함께 참고해두면 좋습니다.
실전 체크리스트: “정확도 안 떨어지는” 쪽으로 수렴시키는 순서
- FP32 기준 정확도를 고정된 평가셋에서 재현 가능하게 만든다
- 캘리브레이션 데이터/전처리를 서빙과 동일하게 맞춘다
- 가중치는 per-channel symmetric으로 시작한다
- outlier가 의심되면 activation 관측자를 바꾸거나 클리핑 전략을 도입한다
- 레이어별 민감도를 측정해 예외 리스트를 만든다(부분 양자화)
- export 그래프에서 fusion/패턴이 기대대로인지 확인한다
- 목표 디바이스에서 실제 커널과 성능을 확인한다
결론
PyTorch 2.0의 PT2E 기반 INT8 양자화는 “그래프를 고정하고, 관측자를 삽입하고, 캘리브레이션으로 통계를 모아 변환한다”는 점에서 명확하고 강력합니다. 하지만 정확도 하락은 대부분 설정과 데이터 문제에서 시작합니다.
정확도를 지키려면 캘리브레이션을 서빙 분포에 맞추고, per-channel 가중치를 기본으로 두며, outlier 대응과 레이어별 예외 처리(부분 양자화)를 체계적으로 적용해야 합니다. 특히 레이어별 민감도 측정으로 예외 리스트를 만드는 접근은, 감에 의존하지 않고 빠르게 정확도를 복구하는 가장 현실적인 방법입니다.