Published on

PyTorch 2.x PT2E INT8 양자화 실패 해결 가이드

Authors

서빙 비용을 낮추기 위해 INT8 양자화를 붙였는데, PyTorch 2.x에서 권장하는 PT2E(Prepare/Convert to Edge) 플로우로 넘어오면서 갑자기 convert_pt2e 단계에서 터지거나, 실행은 되는데 정확도가 크게 깨지는 사례가 많습니다. 문제는 대개 “양자화 자체”가 아니라 그래프 캡처(Export) 조건, 관측기(Observer) 배치, 지원되는 연산 패턴, 백엔드(x86 FBGEMM vs ARM QNNPACK/XNNPACK) 불일치 같은 경계에서 발생합니다.

이 글은 PT2E INT8 양자화가 실패할 때, 어디서부터 어떻게 확인하고 고칠지 재현 가능한 체크리스트 + 코드 중심으로 정리합니다.

PT2E INT8 양자화 파이프라인 한 장 요약

PT2E는 기존 FX Graph Mode Quantization과 결이 비슷하지만, 핵심은 torch.export 기반의 ExportedProgram을 기준으로 준비/변환이 이뤄진다는 점입니다.

  1. 모델을 eval()로 고정
  2. torch.export.export로 그래프 캡처
  3. prepare_pt2e로 관측기 삽입
  4. 대표 데이터로 캘리브레이션(몇 배치 forward)
  5. convert_pt2e로 INT8 연산으로 치환
  6. (선택) torch.compile 또는 런타임에서 실행

실패는 보통 2~5 사이에서 발생합니다.

실패 유형 1: Export 단계에서 그래프 캡처가 안 됨

증상

  • torch.export.export에서 에러
  • 동적 shape, 데이터 의존 제어 흐름, Python side branching 때문에 실패
  • 흔한 메시지: GuardOnDataDependentSymNode 류, 또는 export 불가 연산 경고

해결 전략

  1. 입력 shape를 고정하고 export를 먼저 성공시키세요.
  2. forward 내부의 Python 분기(예: if x.sum() ...)를 제거하거나 텐서 연산으로 바꾸세요.
  3. export가 어려운 부분은 모듈 경계로 분리하거나, 해당 서브모듈만 양자화 대상으로 삼는 것도 방법입니다.

예시 코드: export 최소 재현 템플릿

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = M().eval()
example = (torch.randn(1, 3, 224, 224),)

ep = torch.export.export(model, example)
print(ep)

export가 여기서부터 안 되면, 양자화로 넘어가기 전에 모델 구조를 먼저 정리해야 합니다.

실패 유형 2: 백엔드/엔진 불일치로 인한 변환 실패

증상

  • convert_pt2e에서 특정 quantized op를 못 찾거나, 실행 시 커널이 없다는 에러
  • x86 서버에서는 되는데 ARM 환경에서 깨짐(또는 반대)

체크 포인트

  • x86 CPU: 보통 fbgemm
  • ARM/mobile: 보통 qnnpack 또는 XNNPACK 계열

PyTorch의 전역 엔진 설정이 맞지 않으면 변환은 되더라도 실행 시점에 커널이 없어 실패할 수 있습니다.

엔진 확인/설정

import torch

print(torch.backends.quantized.supported_engines)
# 예: ['qnnpack', 'none', 'onednn', 'fbgemm']

# x86 서버에서 보통
torch.backends.quantized.engine = "fbgemm"

# ARM에서 보통
# torch.backends.quantized.engine = "qnnpack"

print("engine:", torch.backends.quantized.engine)

엔진을 바꿨다면 export부터 다시 하는 게 안전합니다. 그래프 변환 및 패턴 매칭이 엔진에 따라 달라지는 경우가 있습니다.

실패 유형 3: 관측기(Observer) 삽입은 됐는데 캘리브레이션이 무효

증상

  • 변환은 성공했는데 정확도 급락
  • activation scale이 비정상적으로 커지거나, 거의 0에 가까워짐
  • 대표 데이터가 실제 분포를 반영하지 못함

해결 전략

  1. 캘리브레이션 데이터는 “학습 데이터 일부” 또는 “실제 트래픽 샘플”로 구성하세요.
  2. 최소 수십~수백 배치 정도를 권장합니다(모델/도메인에 따라 다름).
  3. 전처리(정규화, resize, dtype)가 학습/서빙과 동일해야 합니다.

PT2E 준비/캘리브레이션/변환 예시

import torch
import torch.nn as nn

from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

# 1) 모델/예시 입력
model = M().eval()
example = (torch.randn(1, 3, 224, 224),)

# 2) export
ep = torch.export.export(model, example)

# 3) quantizer 설정 (x86 예시)
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

# 4) prepare
prepared = prepare_pt2e(ep, quantizer)

# 5) calibration
prepared.module()  # ExportedProgram 내부 모듈 접근용
with torch.no_grad():
    for _ in range(100):
        x = torch.randn(1, 3, 224, 224)
        prepared.module()(x)

# 6) convert
quantized_ep = convert_pt2e(prepared)

# 7) run
with torch.no_grad():
    y = quantized_ep.module()(*example)
print(y.shape)

위 코드는 환경에 따라 import 경로나 quantizer가 다를 수 있습니다. 핵심은 prepare 이후 캘리브레이션 forward가 실제로 수행되어 observer 통계가 채워졌는지입니다.

캘리브레이션이 실제로 돌았는지 확인하는 팁

  • prepare 직후에는 그래프에 observer가 삽입되어야 합니다.
  • 캘리브레이션 후에는 observer 내부의 min/max 또는 histogram 통계가 업데이트됩니다.

실전에서는 통계를 직접 출력하기가 번거로우니, 먼저 캘리브레이션 배치를 0으로 두고 변환했을 때충분히 돌렸을 때 정확도가 얼마나 차이 나는지부터 비교해도 진단에 도움이 됩니다.

실패 유형 4: 지원되지 않는 연산 패턴 때문에 변환이 중단됨

증상

  • convert_pt2e에서 특정 op가 quantize/dequantize 사이에 남아 실패
  • “이 연산은 INT8로 내릴 수 없다”에 가까운 상황

자주 걸리는 패턴

  • LayerNorm, Softmax, 일부 GELU 변형
  • attention 블록의 특정 합성 패턴
  • 커스텀 op 또는 torch.where 같은 조건 연산이 섞인 경로

해결 전략

  1. 해당 블록을 FP16/FP32로 유지하고 나머지만 INT8로 내리기
  2. 모델을 약간 리팩터링해서 quantizable 패턴(Conv+ReLU, Linear+ReLU 등)으로 유도
  3. 가능하면 PyTorch가 제공하는 fused 패턴을 활용

특정 모듈만 양자화 제외(개념 예시)

PT2E에서는 “모듈 단위 제외”가 기존 eager 방식처럼 단순하진 않지만, 현실적으로는 다음 접근이 많이 쓰입니다.

  • 양자화 대상 서브모듈만 별도 모델로 분리
  • export/prepare/convert를 해당 서브모듈에만 적용
  • 나머지는 FP로 연결

이 방식은 디버깅이 쉽고, 실패 지점을 빠르게 격리할 수 있습니다.

실패 유형 5: 동적 shape 또는 배치 크기 변화로 런타임 실패

증상

  • 캘리브레이션은 batch=1로 했는데, 서빙에서 batch가 바뀌면 성능/정확도/실행이 흔들림
  • export 시점의 입력 제약(guards) 때문에 다른 shape에서 실행 불가

해결 전략

  1. 서빙에서 사용될 대표 shape 세트를 확정하고, 그에 맞춰 export를 수행
  2. 정말 동적 shape가 필요하다면, export에서 dynamic shapes 옵션을 검토하되(지원 범위 확인), 우선은 고정 shape로 성공시키는 게 우선입니다.

디버깅 루틴: 어디서 실패했는지 10분 안에 좁히기

  1. torch.export.export가 성공하는가
  2. prepare_pt2e 이후 prepared 모델이 forward 되는가
  3. 캘리브레이션을 여러 배치 수행했는가
  4. convert_pt2e가 성공하는가
  5. 변환 후 출력이 FP32 대비 어느 정도 일치하는가

FP32 vs INT8 출력 비교(간단 스모크 테스트)

import torch

def compare(fp_model, int8_ep, x):
    fp_model.eval()
    with torch.no_grad():
        y_fp = fp_model(x)
        y_q = int8_ep.module()(x)
    diff = (y_fp - y_q).abs().mean().item()
    print("mean abs diff:", diff)

x = torch.randn(1, 3, 224, 224)
compare(model, quantized_ep, x)

diff가 비정상적으로 크면, 캘리브레이션 데이터/전처리/observer 통계부터 의심하는 게 확률이 높습니다.

운영 관점 팁: 실패를 ‘한 번에’ 잡으려 하지 말 것

PT2E INT8 양자화는 모델마다 “되는 패턴”과 “안 되는 패턴”이 비교적 명확합니다. 그래서 한 번에 엔드투엔드로 밀어붙이기보다, 실패를 작은 단위로 쪼개는 게 비용을 줄입니다. 이 접근은 인프라/런타임 문제를 디버깅할 때도 동일합니다. 예를 들어 캐시가 꼬인 빌드 문제를 한 번에 해결하려고 하기보다, 캐시 레이어를 분해해 원인을 좁히는 식입니다. 비슷한 문제 해결 방식은 Docker 빌드가 느릴 때 BuildKit 캐시 깨짐 복구에서도 유효합니다.

또한 양자화 실패를 “모델 문제”로만 보지 말고, 배포 환경의 제약(커널/엔진/CPU feature)과 관측/재시도 체계를 함께 보세요. 장애 대응 관점의 디버깅 루틴은 Kubernetes CrashLoopBackOff 원인별 로그·Probe·리소스 디버깅 글의 방식과도 통합니다.

체크리스트(실전용)

  • model.eval() 상태에서 export했는가
  • export 입력의 dtype/shape가 서빙과 일치하는가
  • torch.backends.quantized.engine이 환경에 맞는가
  • prepare_pt2e 이후 캘리브레이션 forward를 충분히 수행했는가
  • 캘리브레이션 데이터가 실제 분포를 반영하는가
  • 변환 실패 시, 문제 op가 포함된 블록을 분리해 최소 단위로 재현했는가
  • 변환 성공 후 FP32 대비 출력 차이를 스모크 테스트로 확인했는가

마무리

PyTorch 2.x에서 PT2E로 INT8 양자화를 적용할 때의 실패는, 대부분 “지원되지 않는 그래프/패턴” 또는 “캘리브레이션/엔진 설정의 불일치”로 귀결됩니다. 가장 빠른 해결법은 export 성공을 1차 목표로 두고, 그 다음에 prepare-캘리브레이션-convert를 단계별로 검증하는 것입니다. 이 루틴만 습관화해도, convert_pt2e에서 막혔을 때 원인을 추측이 아니라 근거로 좁힐 수 있습니다.