- Published on
PyTorch 2.2 PT2E INT8 양자화 오류 해결 가이드
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙/엣지 환경에서 INT8 양자화는 여전히 가장 강력한 비용 절감 카드입니다. 문제는 PyTorch 2.2부터 권장 경로가 PT2E(Export 기반 양자화 파이프라인)로 이동하면서, 기존 FX Graph Mode 양자화에서 잘 되던 모델이 export 단계에서 깨지거나, prepare_pt2e 이후 convert_pt2e에서 변환이 실패하거나, 변환은 되는데 실행 시 backend에서 런타임 에러가 나는 경우가 많다는 점입니다.
이 글은 PyTorch 2.2 PT2E로 INT8 양자화할 때 흔히 마주치는 오류를 증상별로 나누고, 원인과 해결책을 재현 가능한 코드와 함께 정리합니다. TorchDynamo, torch.export, torch.ao.quantization의 상호작용을 이해하면 디버깅 시간이 크게 줄어듭니다.
관련해서 torch.compile 기반 최적화 흐름은 아래 글도 함께 보면 전체 그림이 잡힙니다.
PT2E INT8 양자화 파이프라인 한 장 요약
PT2E(Post Training 2 Export) 양자화의 핵심은 Export로 고정된 그래프를 만든 뒤, 그 그래프에 관측기(observer)를 삽입해 캘리브레이션하고, 마지막으로 INT8 연산으로 변환하는 것입니다.
일반적인 흐름은 다음 4단계입니다.
torch.export.export로ExportedProgram생성prepare_pt2e로 관측기 삽입- 캘리브레이션 데이터로 몇 번 실행
convert_pt2e로 INT8 변환
주의할 점은, PT2E는 동적 제어 흐름, 데이터 의존 shape, 일부 in-place, 일부 커스텀 op에 민감하고, backend(예: FBGEMM, XNNPACK) 제약까지 같이 고려해야 한다는 것입니다.
기준이 되는 “정상 동작” 코드 템플릿
먼저 최소한의 정상 템플릿을 잡아두면, 오류가 났을 때 어느 단계에서 문제인지 빠르게 분리할 수 있습니다.
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(128, 256)
self.act = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
model = MLP().eval()
example = (torch.randn(1, 128),)
# 1) Export
exported = torch.export.export(model, example)
# 2) Quantizer 설정 (모바일/CPU에서 XNNPACK을 많이 사용)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
# 3) Prepare
prepared = prepare_pt2e(exported, quantizer)
# 4) Calibration
for _ in range(16):
prepared.module()(*example)
# 5) Convert
quantized = convert_pt2e(prepared)
# Run
out = quantized.module()(*example)
print(out.shape)
이 템플릿을 기준으로, 아래에서 설명하는 오류들은 대개 export 단계, prepare 단계, convert 단계, 또는 backend 실행 단계에서 터집니다.
오류 1: export 단계에서 Graph break / Unsupported
대표 증상
torch.export.export에서 예외 발생- 메시지에
Graph break,Unsupported,Data dependent,Guard등이 포함 - 흔히
if x.sum() > 0:같은 데이터 의존 분기, Python 리스트 조작,.item()사용이 원인
원인
PT2E는 Export로 정적 그래프를 만들어야 합니다. 즉, 실행 중 텐서 값에 따라 Python 레벨에서 분기하거나, 텐서를 Python 스칼라로 빼서 제어 흐름에 쓰면 export가 불가능해집니다.
특히 다음 패턴이 자주 문제를 만듭니다.
if tensor.item() > 0:for _ in range(tensor.shape[0]):처럼 런타임 shape에 의존- Python list append로 텐서를 모아
torch.cat하기
해결 1: 제어 흐름을 텐서 연산으로 치환
# 나쁜 예
if x.sum().item() > 0:
x = x * 2
else:
x = x / 2
# 좋은 예: torch.where로 텐서 연산화
cond = x.sum() > 0
x = torch.where(cond, x * 2, x / 2)
해결 2: torch._dynamo.explain로 정확한 break 지점 찾기
import torch
def run(m, inp):
return m(*inp)
explain = torch._dynamo.explain(run)(model, example)
print(explain)
explain 결과에서 graph break가 난 라인/함수를 먼저 제거하거나 텐서 연산으로 바꾸면 export 성공률이 크게 올라갑니다.
오류 2: prepare 이후 캘리브레이션 실행 시 dtype/shape mismatch
대표 증상
prepared.module()(*example)실행 중 에러expected ... got ...형태의 dtype mismatch- Conv/Linear 입력 채널 shape mismatch
원인
PT2E의 prepare_pt2e는 그래프에 관측기를 삽입하면서 일부 연산이 재배치되거나, 모델이 기대하는 입력 스펙이 더 엄격해지는 경우가 있습니다. 특히 다음이 흔합니다.
- 예제 입력이 실제 서빙 입력과 다름
- export 당시의 입력 shape와 캘리브레이션 입력 shape가 다름
- 모델 내부에서
view/reshape가 하드코딩되어 있어 다른 배치/시퀀스 길이를 못 받음
해결: export 입력과 캘리브레이션 입력을 동일 스펙으로 고정
- export에 쓰는
example을 “서빙에서 가장 흔한 shape”로 고정 - 캘리브레이션도 동일 shape로 반복
- shape 가변이 필요하면, 먼저 모델에서
reshape하드코딩을 제거하고flatten(start_dim=1)같이 안정적인 연산으로 대체
# 나쁜 예: 배치가 바뀌면 깨질 수 있음
x = x.view(1, -1)
# 좋은 예
x = x.flatten(start_dim=1)
오류 3: convert 단계에서 “양자화 패턴 매칭 실패”
대표 증상
convert_pt2e에서 에러 또는 변환 후 INT8 op가 거의 생성되지 않음- 결과 그래프를 보면 float 연산만 남아 있음
원인
대부분은 backend가 지원하는 패턴으로 그래프가 만들어지지 않았기 때문입니다.
예를 들어 XNNPACK/FBGEMM은 특정 형태의 Conv/Linear + ReLU 패턴을 선호합니다. 그런데 모델이 아래처럼 작성되어 있으면 패턴 매칭이 깨질 수 있습니다.
F.relu(x)대신 커스텀 활성화add/mul이 섞인 잔차 블록에서 fusion이 어렵게 구성- LayerNorm, GELU 중심 구조(Transformer 계열)는 INT8 경로가 제한적
해결 1: 가능한 표준 모듈로 표현
# 나쁜 예
x = torch.nn.functional.relu(x)
# 좋은 예
self.act = nn.ReLU()
x = self.act(x)
PT2E는 표준 모듈 형태에서 더 안정적으로 패턴을 인식합니다.
해결 2: 어떤 op가 양자화됐는지 확인
변환이 되었는지 확인하려면 그래프를 출력해 “quantize/dequantize” 또는 “int8 packed weight”류 노드가 있는지 봅니다.
print(quantized)
# 또는
print(quantized.graph_module.graph)
(환경에 따라 출력 형태가 다를 수 있으니, 핵심은 변환 후 그래프에 INT8 관련 노드가 늘었는지 확인하는 것입니다.)
오류 4: backend 런타임 에러 (FBGEMM/XNNPACK)
대표 증상
- 변환은 성공했는데 실행 시 에러
- 메시지에
FBGEMM또는XNNPACK이 언급 - CPU에서만 재현되거나, 특정 연산에서만 크래시
원인
INT8 실행은 backend 제약을 강하게 받습니다.
- FBGEMM은 서버 CPU에서 강하고, XNNPACK은 모바일/경량 CPU에 최적화
- 연산별로 지원 dtype/axis가 제한됨
- per-channel, per-tensor 등 quant scheme이 backend와 안 맞으면 런타임 에러가 날 수 있음
해결 1: backend에 맞는 Quantizer를 선택
- 모바일/엣지:
XNNPACKQuantizer - 서버 CPU: (환경에 따라) FBGEMM 기반 구성이 유리
PT2E에서는 “어떤 quantizer를 쓰느냐”가 곧 backend 정책입니다. XNNPACK을 타깃이면 XNNPACK quantizer를 쓰고, 서버 CPU라면 PyTorch 문서/릴리스 노트에 맞는 quantizer 구성을 확인해야 합니다.
해결 2: 문제 op를 float로 남기는 “부분 양자화”
현실적으로 Transformer 계열이나 특수 연산이 끼어 있으면 전체를 INT8로 만들기 어렵습니다. 이때는 문제 구간만 float로 두고, Conv/Linear 같은 큰 연산만 INT8로 만드는 전략이 효과적입니다.
PT2E에서는 보통 “global 설정 + 예외 처리” 방식으로 접근합니다. 프로젝트 상황마다 API가 조금씩 달라질 수 있으니, 핵심은 아래 두 가지입니다.
- 전역은 INT8
- 특정 모듈/연산은 quantization 제외
구현은 모델 구조에 맞춰 모듈 단위로 분기하는 것이 가장 안전합니다.
# 예시: 특정 서브모듈은 양자화 제외를 위해 forward에서 분리하거나,
# quantizer 설정에서 해당 모듈을 제외하도록 구성(프로젝트별 API 상이)
class Model(nn.Module):
def __init__(self, backbone, head):
super().__init__()
self.backbone = backbone # INT8 타깃
self.head = head # float 유지(예: LayerNorm-heavy)
def forward(self, x):
x = self.backbone(x)
x = self.head(x)
return x
오류 5: in-place 연산 때문에 export/quantization 실패
대표 증상
add_,relu_,mul_같은 in-place가 포함된 모델에서 export 실패- 또는 prepare/convert 이후 결과가 틀어짐
원인
in-place는 그래프 변환(관측기 삽입, op 교체)에서 aliasing 문제를 만들기 쉽습니다. 특히 residual connection에서 x += y 같은 코드는 export/quant에서 불안정 요소입니다.
해결: in-place 제거
# 나쁜 예
x.add_(res)
# 좋은 예
x = x + res
성능상 in-place를 썼더라도, INT8 양자화/컴파일 파이프라인에서는 안정성이 우선입니다.
오류 6: 캘리브레이션 품질 문제로 정확도가 급락
대표 증상
- 변환은 성공, 속도도 개선
- 그런데 정확도가 크게 떨어짐
원인
PTQ INT8의 성패는 관측기 통계에 달려 있습니다.
- 캘리브레이션 데이터가 너무 적음
- 실제 입력 분포를 반영하지 못함
- outlier가 많은 분포에서 symmetric quant만 쓰면 손실이 커짐
해결 1: 캘리브레이션 데이터 “분포 대표성” 확보
- 최소 수십~수백 배치 권장(모델/도메인에 따라 다름)
- 실제 서빙 트래픽 샘플을 섞기
- 전처리(정규화/리사이즈)가 서빙과 동일해야 함
해결 2: 레이어별 민감도 측정 후 부분 양자화
- Linear/Conv는 INT8 유지
- 정확도 민감한 블록(예: 첫 레이어, 마지막 로짓 레이어, 특정 attention 블록)은 float 유지
이 접근은 “속도/정확도 트레이드오프”를 가장 예측 가능하게 만듭니다.
실전 디버깅 체크리스트 (단계별)
1) export가 실패한다
.item()제거, 데이터 의존 분기 제거- in-place 제거
torch._dynamo.explain로 break 지점 특정
2) prepare 후 실행이 실패한다
- export 입력과 캘리브레이션 입력 shape/dtype 일치
- 하드코딩
view제거,flatten등으로 대체
3) convert가 실패하거나 INT8이 거의 안 생긴다
- 표준 모듈 패턴(Conv/Linear + ReLU)로 정리
- 커스텀 활성화/연산 최소화
- 그래프 출력으로 실제 INT8 변환 여부 확인
4) 실행은 되는데 backend에서 터진다
- 타깃 backend에 맞는 quantizer 사용
- 문제 op는 float 유지(부분 양자화)
5) 정확도가 떨어진다
- 캘리브레이션 데이터 품질/양 개선
- 민감 레이어 float 유지
마무리: PT2E INT8은 “그래프 친화적 모델”이 핵심
PyTorch 2.2의 PT2E INT8 양자화는 기존 FX 기반보다 더 “컴파일러 파이프라인”에 가깝습니다. 그래서 모델을 다음 원칙으로 정리하면 성공률이 크게 올라갑니다.
- Python 제어 흐름을 텐서 연산으로 바꾸기
- in-place 제거로 그래프 변환 안정화
- export/캘리브레이션 입력 스펙 고정
- backend가 좋아하는 패턴(표준 모듈) 유지
- 전체 INT8이 어렵다면 부분 양자화로 현실적인 타협
추론 성능 최적화 관점에서 torch.compile과의 조합도 자주 쓰이니, 전체 최적화 플로우는 다음 글을 같이 참고하면 좋습니다.
원하면, 여러분이 겪는 실제 에러 로그(에러 메시지, 모델 타입, 타깃 backend, CPU/GPU 환경, 예제 입력 shape)를 기준으로 어느 단계에서 깨지는지를 분류해서, 해당 케이스에 맞는 최소 수정 패치를 제안해드릴 수 있습니다.