Published on

PyTorch 2.x PT2E INT8 양자화 에러 해결 가이드

Authors

PyTorch 2.x에서 INT8 양자화를 하려다 보면, 예전 FX Graph Mode Quantization보다 PT2E(Export 기반 양자화) 쪽에서 에러 메시지가 더 낯설게 느껴질 때가 많습니다. 이유는 간단합니다. PT2E는 torch.export그래프를 먼저 “고정” 한 뒤, 그 위에서 관측자 삽입과 변환을 수행하기 때문에, 모델이 가진 동적 제어 흐름, dtype 혼합, 연산 패턴이 더 엄격하게 검증됩니다.

이 글은 “일단 돌려보면 터지는” PT2E INT8 양자화 에러를 원인별로 분류하고, 재현 가능한 최소 코드와 함께 해결책을 제시합니다. 특히 서버 환경에서 양자화 파이프라인이 메모리/리소스 이슈로도 자주 무너지는 편이라, 진단 팁도 같이 넣었습니다. (메모리 관련 트러블슈팅은 리눅스 OOM Killer 로그로 메모리 누수 추적하기, 쿠버네티스 환경이면 Kubernetes OOMKilled 진단과 메모리 누수 추적 실전도 함께 보시면 좋습니다.)

PT2E INT8 양자화 파이프라인 한 장 요약

PT2E에서 대표적으로 많이 쓰는 흐름은 아래입니다.

  1. torch.export.export로 모델을 Export(정적 그래프화)
  2. prepare_pt2e로 관측자(Observer) 삽입
  3. 캘리브레이션 데이터로 몇 step 실행
  4. convert_pt2e로 INT8 연산으로 변환
  5. (선택) torch.compile로 최적화

아래 코드는 “돌아가는 기준선”입니다. 에러가 나면 이 기준선과 비교하면서 무엇이 다른지 확인하세요.

import torch
import torch.nn as nn

# PT2E APIs
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

device = "cpu"
model = M().eval().to(device)
example = (torch.randn(1, 3, 224, 224, device=device),)

# 1) export
exported = torch.export.export(model, example)

# 2) quantizer 설정 (XNNPACK: CPU INT8에서 흔히 사용)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# 3) prepare
prepared = prepare_pt2e(exported, quantizer)

# 4) calibration (몇 번만 돌려도 됨)
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(1, 3, 224, 224, device=device))

# 5) convert
quantized = convert_pt2e(prepared)

# smoke test
with torch.no_grad():
    y = quantized(torch.randn(1, 3, 224, 224, device=device))
print(y.shape)

이 기준선이 돌아가는데도 여러분 모델에서만 문제가 생긴다면, 아래 섹션의 “증상-원인-해결” 패턴으로 좁혀갈 수 있습니다.

에러 유형 1: Export 단계에서 깨지는 경우

증상

  • torch.export에서 실패
  • 메시지 예시(환경에 따라 다름)
    • 인플레이스 연산/동적 제어 흐름/지원되지 않는 연산 관련 에러
    • 그래프가 고정되지 않는다는 류의 에러

대표 원인

  1. 데이터 의존 분기: if x.sum() > 0: 같은 패턴
  2. 가변 shape 기반 로직: 입력 크기에 따라 다른 모듈을 타거나, reshape가 조건부로 바뀜
  3. inplace 연산: relu_(), add_() 등이 그래프 내에서 문제를 유발

해결

  • 분기 제거 또는 torch.cond 계열로 치환(가능한 경우)
  • 인플레이스 연산을 아웃오브플레이스로 변경
  • Export 입력 예시(example_inputs)를 실제 배치/shape와 맞추고, shape 변형을 고정

예: 인플레이스 제거

# bad
x.relu_()

# good
x = torch.relu(x)

또 하나의 실전 팁은 Export 전에 forward를 최대한 단순화하는 것입니다. 예를 들어 학습용 로직(드롭아웃, stochastic depth, aux loss)을 eval()에서 완전히 꺼도 남는다면, 양자화용 forward를 별도로 두는 것도 방법입니다.

에러 유형 2: Prepare/Convert에서 “양자화 패턴 매칭 실패”

증상

  • prepare_pt2e 또는 convert_pt2e에서 실패
  • 혹은 변환은 되는데 실행 시 특정 op에서 런타임 에러
  • 흔한 상황: Conv/Linear는 되는데, 중간의 커스텀 연산이나 reshape/concat 주변에서 깨짐

대표 원인

  1. 지원되지 않는 연산이 INT8 경로에 섞임
  2. 패턴이 끊김: 예를 들어 Conv2d 다음에 바로 ReLU가 오지 않고, 중간에 permute, view, add 등이 끼어들어 백엔드가 기대하는 fusion 패턴이 깨짐
  3. 연산이 float로 강제 캐스팅됨: x.float() 같은 코드가 중간에 있으면 quant/dequant가 과도하게 삽입되거나 변환 실패로 이어질 수 있음

해결

  • “문제 구간만 FP16/FP32로 남기기” 전략을 쓰는 게 가장 빠릅니다.
  • XNNPACKQuantizer는 전역 설정 외에도, 연산/모듈 단위로 제외(혹은 포함) 전략을 세울 수 있습니다.

실전적으로는 아래 순서가 효율적입니다.

  1. 일단 전역 INT8을 시도
  2. 실패하면, 문제되는 모듈을 찾아 제외
  3. 제외한 상태로 전체 파이프라인을 통과시킨 뒤, 제외 범위를 최소화

모듈 분리 예시(양자화 친화적으로 블록화)

class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

블록화는 단순한 리팩토링처럼 보이지만, PT2E에서 패턴 매칭을 안정화시키는 데 꽤 효과가 있습니다.

에러 유형 3: dtype/스케일 관련 런타임 에러(특히 add, cat, residual)

증상

  • 변환은 되는데 실행 중 add/cat/residual에서 에러
  • 또는 결과가 비정상(정확도 급락, NaN)

대표 원인

  1. 서로 다른 quantization parameter를 가진 텐서를 더함
  2. residual 경로에서 한쪽은 quantized, 다른 쪽은 float로 남아 dtype이 섞임
  3. activation quantization이 비대칭/대칭 설정과 충돌하거나, 관측자 통계가 왜곡됨

해결

  • residual/add가 있는 블록은 입출력 quant/dequant 경계를 명확히 하거나, 해당 블록만 float로 유지
  • 캘리브레이션 데이터를 실제 분포에 가깝게 준비(랜덤 데이터로 캘리브레이션하면 스케일이 망가져 정확도/안정성이 크게 떨어질 수 있음)

캘리브레이션은 “몇 step만 돌리면 된다”가 맞지만, 그 몇 step이 실제 입력 분포를 반영해야 합니다.

에러 유형 4: torch.compile과 함께 쓸 때 깨지는 경우

증상

  • PT2E로 변환한 모델을 torch.compile 하자마자 실패
  • 혹은 컴파일은 되는데 실행 시 특정 backend에서 실패

대표 원인

  • 컴파일러 백엔드가 특정 quantized op를 아직 완전히 지원하지 않음
  • dynamic shape, guard, decomposition 과정에서 quantized graph가 예상과 다르게 변형됨

해결

  • 우선순위는 “양자화 성공”입니다. 즉,
    1. torch.compile 없이 양자화 모델이 정상 동작하는지 확인
    2. 그 다음 torch.compile(quantized)를 시도
    3. 실패하면 컴파일 옵션을 보수적으로 조정하거나, 컴파일을 포기하고 eager로 운영

컴파일을 붙일 때는 아래처럼 단계적으로 확인하세요.

# 1) eager 실행 확인
with torch.no_grad():
    _ = quantized(torch.randn(1, 3, 224, 224))

# 2) compile 시도
compiled = torch.compile(quantized, mode="max-autotune")
with torch.no_grad():
    _ = compiled(torch.randn(1, 3, 224, 224))

또한 디버깅 중에는 mode를 바꿔가며 재현성을 확보하는 게 좋습니다. 성능 최적화 모드일수록 그래프 변형이 커져서 원인 파악이 어려워질 수 있습니다.

에러 유형 5: 메모리/리소스 문제로 파이프라인이 중간에 죽음

증상

  • 캘리브레이션 중 프로세스가 죽거나, 컨테이너가 OOMKilled
  • 로그 없이 종료되는 것처럼 보임

대표 원인

  • Export/Prepare 단계에서 그래프와 메타데이터가 커짐
  • 캘리브레이션 배치를 과하게 크게 잡음
  • 같은 프로세스에서 여러 모델을 반복 양자화하며 메모리 누수처럼 누적

해결

  • 캘리브레이션 배치/스텝을 최소화(예: 배치 1, 스텝 8~32부터 시작)
  • 양자화 작업을 별도 프로세스로 분리(작업 단위로 프로세스 재시작)
  • OOM 로그를 먼저 확인해 “진짜 OOM인지”를 확정

리눅스/쿠버네티스에서 OOM 확인은 아래 글 흐름이 그대로 적용됩니다.

디버깅 체크리스트(재현-격리-수정)

1) “최소 재현”부터 만든다

  • 입력 shape 하나로 고정
  • eval() 강제
  • dropout/랜덤성 제거
torch.manual_seed(0)
model.eval()

2) Export 결과를 기준으로 본다

PT2E는 Export가 기준입니다. Export가 깨지면 그 아래는 모두 의미가 없습니다.

exported = torch.export.export(model, example)
print(type(exported))

3) 캘리브레이션을 현실 데이터로

랜덤 텐서로 캘리브레이션하면 activation range가 실제와 달라져, INT8에서 클리핑이 과해지거나 스케일이 비정상적으로 잡힐 수 있습니다.

with torch.no_grad():
    for batch in calib_loader:
        prepared(batch)

4) 문제 구간을 float로 남기는 전략을 빠르게 적용

“전체 INT8”이 목적이어도, 운영에서 중요한 건 성공적으로 돌아가는 모델입니다. 특히 attention, 특수 activation, 커스텀 op는 부분 float이 현실적인 타협점이 됩니다.

결론: PT2E INT8 에러는 ‘그래프 고정’ 관점으로 풀린다

PT2E 기반 INT8 양자화 에러는 대부분 아래 중 하나로 귀결됩니다.

  • Export 단계에서 그래프를 고정할 수 없다
  • 양자화 백엔드가 기대하는 패턴(Conv/Linear 주변)이 끊긴다
  • residual/add/cat에서 dtype과 스케일이 충돌한다
  • torch.compile이 quantized op를 완벽히 감당하지 못한다
  • 캘리브레이션/변환 과정에서 메모리가 터진다

해결의 핵심은 “한 번에 다 고치기”가 아니라, 기준선 코드로 성공 → 모델에 한 조각씩 적용 → 깨지는 지점을 격리하는 방식입니다. 특히 캘리브레이션 데이터를 현실적으로 준비하고, 문제가 되는 블록을 과감히 float로 남기는 것만으로도 실무에서는 대부분 일정 수준의 성능/안정성을 확보할 수 있습니다.

원하시면, 여러분이 겪는 실제 에러 로그(스택 트레이스)와 모델의 핵심 블록(특히 residual/add/cat 부분) 코드를 주시면, PT2E 관점에서 어떤 연산이 병목인지와 “부분 양자화 설계”를 더 구체적으로 같이 잡아드릴게요.