- Published on
PyTorch PTQ 튕김? FX Graph Mode QAT로 해결
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙에서 int8 양자화를 붙이려다 PTQ(Post-Training Quantization) 단계에서 갑자기 튕기거나(런타임 에러), 혹은 튕기진 않는데 정확도가 급락하는 경험은 흔합니다. 특히 비전/추천/트랜스포머 계열 모델에서 연산 그래프가 복잡해질수록, eager 모드 기반 PTQ는 관찰자(observer) 삽입 위치가 애매해지고, 지원되지 않는 연산 패턴이 섞이며, 백엔드별 제약이 겹쳐 “한 번은 되다가 다음 커밋에서 깨지는” 상황이 나옵니다.
이 글은 “PTQ가 튕김”을 단순히 에러 로그로 땜질하지 않고, FX Graph Mode QAT(Quantization-Aware Training) 로 전환해 문제를 구조적으로 해결하는 방법을 다룹니다. 핵심은 다음입니다.
- PTQ가 깨지는 원인을 그래프 관점에서 분해한다
- FX로 모델을 캡처하고, 준비(
prepare_qat_fx)와 변환(convert_fx) 을 일관되게 수행한다 - 대표적으로 많이 깨지는 패턴(Conv+BN+ReLU, residual add, cat, attention, layernorm 등)을 QAT 친화적으로 정리한다
- 학습/검증/내보내기까지 “재현 가능한 파이프라인”으로 만든다
문제 원인 분석 방식은 디버깅 체크리스트를 만드는 관점과 유사합니다. CI에서 캐시가 안 먹는 원인을 체크리스트로 줄이듯, 양자화도 “깨지는 지점”을 체계적으로 분류해야 속도가 납니다. 참고로 비슷한 디버깅 접근은 GitHub Actions 캐시가 안 먹을 때 디버깅 체크리스트 글의 사고방식이 도움이 됩니다.
PTQ가 튕기는 대표 원인 6가지
PTQ는 학습 없이(또는 아주 제한된 캘리브레이션만으로) int8 스케일/제로포인트를 잡습니다. 이때 아래 이슈가 자주 터집니다.
1) 지원되지 않는 연산/패턴이 그래프에 섞임
예: 동적 shape 기반 분기, 특수 커스텀 op, 일부 activation 조합, einsum 기반 attention, F.interpolate 같은 연산.
PTQ는 “관찰자 삽입 + 정적 변환”이 핵심인데, 변환 가능한 패턴이 아니면 변환 단계에서 실패하거나 런타임에서 튕깁니다.
2) 관찰자 삽입 위치가 잘못되어 스케일이 망가짐
PTQ는 캘리브레이션 데이터가 적으면 특히 취약합니다. outlier가 섞이면 activation 스케일이 커져서 유효 비트가 줄고, 정확도가 급락합니다.
3) 백엔드 제약(qnnpack, fbgemm)과 dtype 혼합
서버 x86에서 fbgemm, 모바일 arm에서 qnnpack을 쓰는데, 같은 모델이라도 허용되는 quantized op 조합이 다릅니다. 또한 중간에 fp16이나 bf16이 섞이면 변환이 깨질 수 있습니다.
4) add, cat 같은 텐서 결합 연산에서 스케일 정렬 실패
Residual add는 양자화에서 까다롭습니다. 두 텐서의 스케일이 다르면 내부적으로 requantize가 필요한데, eager PTQ에서는 이 경계가 불안정해지기 쉽습니다.
5) BN folding/ fusion 타이밍 문제
Conv-BN-ReLU 같은 패턴은 “fusion”이 잘 되면 안정적이지만, 모델 구조나 학습 상태(특히 BN의 running stat) 때문에 folding이 꼬이면 PTQ에서 정확도가 크게 흔들립니다.
6) 관찰자/가짜양자화(fake quant) 설정 미스
예: per-tensor vs per-channel, activation observer 종류(minmax vs histogram), weight observer 설정이 백엔드와 안 맞는 경우.
이런 문제는 “PTQ 파라미터를 조금씩 바꿔보는” 방식으로는 수렴이 느립니다. 근본적으로는 QAT로 학습 과정에서 양자화 오차를 모델이 흡수하도록 만들어야 합니다.
왜 FX Graph Mode QAT인가
PyTorch 양자화는 크게 eager 모드와 FX Graph Mode로 나뉩니다.
- eager: 모듈에 직접
qconfig를 달고, 수동으로 fusion/prepare/convert를 수행 - FX Graph Mode:
torch.fx로 그래프를 캡처한 뒤, 정해진 규칙에 따라 변환
FX Graph Mode QAT의 장점은 다음입니다.
- 그래프 기반으로 변환이 일관됨: fusion/observer 삽입이 규칙 기반이라 재현성이 높습니다.
- 디버깅이 쉬움: FX graph를 출력해서 “어디서 quantize/dequantize가 삽입됐는지” 확인 가능합니다.
- 패턴 매칭이 강력: Conv-BN-ReLU 등 대표 패턴을 안정적으로 처리합니다.
- QAT로 정확도 방어: PTQ에서 스케일이 망가지는 문제를 학습으로 흡수합니다.
정리하면, PTQ가 “한 방에” 성공하기 어려운 모델일수록 FX QAT가 실전적입니다.
실전 파이프라인: FX Graph Mode QAT
아래 예시는 fbgemm 백엔드(서버 x86) 기준입니다. 모바일 arm이면 qnnpack으로 바꾸면 됩니다.
1) 준비: 모델을 eval로 안정화하고, fusion 가능한 구조로 정리
- QAT 준비 전에는 BN이 있는 모델을 그대로 쓰되, 학습 루프에서 BN 동작을 어떻게 할지 결정해야 합니다.
- 일반적으로 QAT에서는 BN을 학습 초반에만 업데이트하고, 이후 freeze하는 전략을 많이 씁니다.
2) QAT 준비 코드
import torch
import torch.nn as nn
import torch.ao.quantization as aq
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
def build_qat_model(float_model: nn.Module, example_inputs):
# 1) 백엔드 선택
torch.backends.quantized.engine = "fbgemm"
# 2) QConfigMapping 설정
# 기본은 get_default_qat_qconfig("fbgemm")를 많이 사용
qconfig = aq.get_default_qat_qconfig("fbgemm")
qconfig_mapping = aq.QConfigMapping().set_global(qconfig)
# 3) FX 그래프 모드에서는 prepare 시점에 example_inputs가 중요
float_model.eval()
# 4) prepare_qat_fx: fake quant 모듈 삽입
qat_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
return qat_model
여기서 example_inputs는 단순 샘플이 아니라, 실제 서빙 입력 shape 를 대표해야 합니다. shape가 달라지면 trace/graph가 달라져 변환이 깨질 수 있습니다(특히 분기나 reshape가 많을 때).
3) QAT 학습 루프(핵심 포인트만)
import torch.optim as optim
def train_qat(qat_model, dataloader, num_epochs=3, lr=1e-4, device="cuda"):
qat_model.train()
qat_model.to(device)
optimizer = optim.AdamW(qat_model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
for x, y in dataloader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad(set_to_none=True)
logits = qat_model(x)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
# QAT에서는 epoch마다 eval로 성능 확인 권장
qat_model.eval()
# ... validation ...
qat_model.train()
return qat_model
QAT는 일반 fine-tuning보다 불안정할 수 있어, 다음을 권합니다.
- 학습률을 낮게 시작(
1e-4또는 더 낮게) - early stopping 기준을 명확히
- 캘리브레이션/학습 데이터 분포가 실제 트래픽과 유사해야 함
4) 변환: convert_fx로 진짜 int8 모델 생성
def convert_to_int8(qat_model, device="cpu"):
qat_model.eval()
qat_model.to(device)
# convert_fx가 quantized op로 치환
int8_model = convert_fx(qat_model)
return int8_model
주의: 변환 이후에는 많은 op가 quantized:: 계열로 바뀌고, 디바이스도 보통 cpu에서 실행하는 경우가 많습니다. (서버에서 fbgemm은 CPU int8 최적화가 핵심)
“PTQ 튕김”을 FX QAT로 해결하는 디버깅 전략
QAT로 전환했다고 무조건 해결되진 않습니다. 다만 “어디가 문제인지”를 훨씬 빨리 찾을 수 있습니다.
1) FX 그래프 출력으로 quantize 경계 확인
from torch.fx import GraphModule
def print_fx_graph(model: GraphModule):
print(model.graph)
그래프를 보면 activation_post_process(observer/fake quant) 같은 노드가 어디에 붙는지 확인할 수 있습니다. PTQ에서 튕길 때는 대개 “어떤 노드 앞뒤로 quantize/dequantize가 과도하게 삽입”되거나 “지원되지 않는 노드가 quantized 영역에 들어간” 경우가 많습니다.
2) 문제 연산을 float로 강제(Selective Quantization)
FX에서는 특정 모듈을 quantize 대상에서 제외하는 전략이 유효합니다. 예를 들어 LayerNorm이나 특정 attention 블록은 float로 남겨두고, Conv/Linear 중심으로만 int8을 적용하는 식입니다.
import torch.ao.quantization as aq
qconfig = aq.get_default_qat_qconfig("fbgemm")
qconfig_mapping = (
aq.QConfigMapping()
.set_global(qconfig)
.set_module_name("encoder.layernorm", None) # 예: 이 모듈은 float 유지
)
모듈 이름은 실제 모델 구조에 맞게 조정해야 합니다. 이 방식은 “완전 int8”은 아니지만, PTQ에서 튕기는 지점을 우회하면서도 성능 이득을 상당 부분 가져갈 수 있습니다.
3) Residual add가 많은 모델에서의 체크 포인트
- add 양쪽 텐서가 같은 quantization scheme을 갖는지
- add 직전에 불필요한 dequantize가 끼어 있지 않은지
- 가능하면 블록 단위로 quantize 영역을 설계(너무 촘촘히 끊지 않기)
4) 정확도 급락 시: observer 설정을 바꾸기 전에 데이터부터 점검
PTQ/ QAT 모두 “대표 데이터”가 핵심입니다.
- 캘리브레이션/학습 데이터가 너무 깨끗하면(현실 outlier가 없음) 실서빙에서 깨짐
- 반대로 outlier가 과도하면 스케일이 커져 정밀도가 떨어짐
이건 성능 문제의 원인을 추적하는 방식과 닮았습니다. 예를 들어 INP 스파이크를 Long Task로 추적하듯, 양자화도 “어떤 입력이 스케일을 망가뜨리는지”를 추적해야 합니다. 참고로 관측/추적 관점은 Chrome INP 급등? Long Task 추적·해결 가이드 같은 글의 접근이 유사합니다.
자주 마주치는 에러/증상과 처방
증상 A: convert_fx에서 변환 중 예외 발생
- 원인: 변환 규칙이 없는 op가 quantized 영역에 포함
- 처방: 해당 모듈을 float로 제외하거나, 모델을 단순화(예: 특정 블록을 scriptable한 형태로 변경)
증상 B: 변환은 되는데 실행 시 특정 입력에서만 튕김
- 원인: 동적 shape 또는 분기 조건이 example input과 다름
- 처방:
example_inputs를 실제 입력 분포를 커버하도록 구성(여러 shape를 지원해야 하면 설계를 재검토)
증상 C: 정확도 급락(PTQ 대비 QAT에서도 회복이 안 됨)
- 원인: activation quant가 너무 공격적이거나, 특정 레이어가 quant에 부적합
- 처방: selective quantization으로 핵심 레이어만
int8적용, per-channel weight quant 유지, QAT epoch 증가 및 LR 감소
배포 관점 체크리스트
QAT로 int8 모델을 만들었으면, 배포에서 다시 튕기지 않게 아래를 고정하세요.
torch버전,torchvision버전, quantization backend(fbgemm/qnnpack) 명시- 변환된 모델의 입출력 dtype/shape 계약 문서화
- 성능 측정 시 warm-up 포함, 스레드 수(
torch.set_num_threads) 고정 - CI에서 변환 파이프라인을 스모크 테스트로 고정
실무에서는 “학습 코드는 되는데 CI/CD에서만 깨짐”도 흔합니다. 그럴 때는 환경/캐시/아티팩트 문제까지 포함해 재현성을 잡아야 합니다. 이 관점에서 GitHub Actions 캐시가 안 먹을 때 디버깅 체크리스트 같은 글이 도움이 됩니다.
마무리: PTQ가 튕기면, QAT는 ‘우회’가 아니라 ‘정공법’
PTQ는 빠르고 단순하지만, 모델이 복잡해질수록 “깨질 확률”과 “정확도 리스크”가 커집니다. 반면 FX Graph Mode QAT 는
- 그래프 기반 변환으로 재현성을 확보하고
- fake quant를 학습에 포함해 양자화 오차를 모델이 흡수하게 만들어
- 결과적으로 PTQ에서 튕기던 케이스를 안정적으로
int8서빙까지 연결
할 수 있습니다.
다음 단계로는,
- 어떤 레이어를 float로 남길지(성능 vs 안정성)
- 백엔드별(
fbgemm/qnnpack) 성능 차이 - 모델 저장 및 로딩(예:
torch.export,torchscript)과의 궁합
까지 포함한 “양자화 배포 표준”을 팀 내에 만드는 것을 권합니다. QAT는 한 번 파이프라인을 잡아두면, 이후 모델 버전업 때 PTQ 삽질을 크게 줄여줍니다.