Published on

PyTorch 2.x PTQ로 INT8 양자화 ONNXRT 튜닝

Authors

서빙 비용을 낮추거나 CPU 추론 지연을 줄이려면 INT8 양자화는 거의 필수 옵션입니다. 다만 PyTorch 2.x 환경에서는 torch.compile·torch.export·ONNX 변환·ONNX Runtime(이하 ONNXRT) 양자화까지 체인이 길어지면서, “INT8은 됐는데 정확도가 떨어짐”, “속도는 그대로임”, “특정 연산이 FP32로 남음” 같은 문제가 자주 발생합니다.

이 글은 PyTorch 2.x 모델을 PTQ(Post-Training Quantization)로 INT8로 내리고, ONNXRT에서 Q/DQ 기반으로 튜닝하는 과정을 실무 관점에서 정리합니다. 목표는 다음 3가지입니다.

  • 정확도 하락을 최소화하면서 INT8 적용 범위를 넓히기
  • CPU에서 실제 지연(latency)과 처리량(throughput) 개선을 얻기
  • “왜 느린지/왜 깨지는지”를 재현 가능하게 디버깅하기

배포 관점에서 카나리로 안전하게 교체하는 흐름은 KServe+Knative로 GPU 모델 무중단 카나리 배포도 함께 참고하면 좋습니다.

PTQ + ONNXRT INT8의 큰 그림

PTQ는 학습 없이(또는 미세한 보정만으로) 대표 데이터(캘리브레이션 세트) 를 흘려보내 activation의 통계를 수집한 뒤, 이를 기반으로 scale/zero-point를 정해 INT8로 양자화합니다.

ONNXRT에서 흔히 쓰는 방식은 Q/DQ(QuantizeLinear/DequantizeLinear) 노드 삽입입니다.

  • 장점: 원래 연산 그래프를 보존하면서, 어떤 구간이 INT8로 실행되는지 추적이 쉬움
  • 단점: EP(Execution Provider)와 커널 지원 여부에 따라 INT8로 “붙지” 않을 수 있음

핵심은 “ONNX로 내보낸 뒤 ONNXRT에서 양자화”가 아니라, 모델 구조/연산자/EP/캘리브레이션 품질을 함께 튜닝해야 한다는 점입니다.

준비: 버전/환경 체크리스트

버전 조합에 따라 결과가 크게 달라질 수 있어, 재현 가능한 환경을 먼저 고정하는 편이 좋습니다.

  • PyTorch 2.x (가능하면 최신 2.x)
  • onnx 최신
  • onnxruntime 또는 onnxruntime-gpu
  • CPU INT8 최적화 목적이면 onnxruntime + (가능 시) onnxruntime-extensions는 선택

설치 예시:

pip install -U torch torchvision torchaudio
pip install -U onnx onnxruntime
pip install -U onnxruntime-tools

CPU에서 INT8 성능을 제대로 보려면 스레드/바인딩/전력 모드도 영향을 줍니다. 최소한 다음은 실험마다 동일하게 맞추세요.

  • OMP_NUM_THREADS, MKL_NUM_THREADS
  • 프로세스 affinity(고정 가능하면 고정)
  • 동일 입력 배치/시퀀스 길이

1) 캘리브레이션 데이터: “정확도의 80%”

PTQ에서 가장 흔한 실패 원인은 캘리브레이션 데이터가 실제 트래픽 분포를 대표하지 못하는 것입니다.

캘리브레이션 세트 구성 가이드

  • 실제 서빙 입력 분포를 반영(길이, 해상도, 도메인, 전처리 포함)
  • 데이터 개수는 모델/도메인에 따라 다르지만, 경험적으로
    • CV 분류: 수백~수천 장
    • NLP: 수백~수천 문장(길이 다양하게)
    • 멀티모달/디텍션: 더 많이 필요할 수 있음
  • 극단값(outlier)이 너무 많으면 activation 범위가 커져 INT8 해상도가 떨어질 수 있으니, 클리핑/퍼센타일 기반 캘리브레이션을 고려

전처리 일치가 중요

PyTorch에서 학습/추론하던 전처리를 ONNXRT 캘리브레이션에도 동일하게 적용해야 합니다. 예를 들어 이미지라면 resize, normalize(mean/std), 채널 순서(NCHW)까지 정확히 동일해야 합니다.

2) PyTorch 2.x 모델을 ONNX로 내보내기

PyTorch 2.x에서는 torch.onnx.export도 가능하지만, 그래프 안정성을 위해 torch.export 기반 경로를 쓰는 경우가 늘었습니다. 여기서는 널리 쓰이는 torch.onnx.export 예시를 들되, 동적 축과 연산자 호환성에 주의합니다.

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim=768, hidden=2048, out_dim=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim)
        )

    def forward(self, x):
        return self.net(x)

model = MLP().eval()
dummy = torch.randn(1, 128, 768)  # 예: (batch, seq, hidden)

torch.onnx.export(
    model,
    dummy,
    "model_fp32.onnx",
    opset_version=17,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={
        "input": {0: "batch", 1: "seq"},
        "logits": {0: "batch", 1: "seq"}
    },
)

튜닝 관점에서 중요한 포인트:

  • opset_version는 너무 낮으면 Q/DQ 삽입 후 커널 지원이 애매해질 수 있음
  • 동적 축이 많을수록 일부 최적화가 제한될 수 있음(특히 CPU에서)
  • LayerNorm, GELU 같은 연산이 INT8로 완전히 내려가지 않는 경우가 흔함(혼합 정밀이 정상일 수 있음)

3) ONNXRT PTQ: 정적(static) 양자화로 Q/DQ 삽입

ONNXRT의 PTQ는 크게

  • Dynamic quantization: 주로 weight만 INT8, activation은 런타임에 처리(간편하지만 성능/정확도 한계)
  • Static quantization: 캘리브레이션으로 activation 통계를 잡아 Q/DQ 삽입(보통 더 좋은 성능/정확도)

여기서는 Static을 기준으로 설명합니다.

캘리브레이션 데이터 리더 예시

ONNXRT quantization API는 캘리브레이션 데이터를 iterator로 받습니다.

import numpy as np
from onnxruntime.quantization import CalibrationDataReader

class NumpyCalibrationDataReader(CalibrationDataReader):
    def __init__(self, np_arrays):
        self.data_iter = iter([
            {"input": arr.astype(np.float32)} for arr in np_arrays
        ])

    def get_next(self):
        return next(self.data_iter, None)

Static quantization 실행

from onnxruntime.quantization import (
    quantize_static,
    QuantFormat,
    QuantType,
    CalibrationMethod,
)

fp32_path = "model_fp32.onnx"
int8_path = "model_int8_qdq.onnx"

# 예시 캘리브레이션 입력 (실무에서는 실제 데이터로 구성)
calib_inputs = [np.random.randn(1, 128, 768) for _ in range(200)]
reader = NumpyCalibrationDataReader(calib_inputs)

quantize_static(
    model_input=fp32_path,
    model_output=int8_path,
    calibration_data_reader=reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    calibrate_method=CalibrationMethod.MinMax,
)

여기서부터가 “튜닝”의 시작입니다.

4) 캘리브레이션 방법 튜닝: MinMax vs Entropy vs Percentile

정확도 하락이 크면, 가장 먼저 바꿔볼 축은 캘리브레이션 방식입니다.

  • MinMax: 구현이 단순하지만 outlier에 취약
  • Entropy: 분포를 더 잘 반영하는 경우가 많음(모델/도메인 의존)
  • Percentile: outlier를 일정 비율로 클리핑해 해상도를 확보

ONNXRT 버전에 따라 지원 옵션이 다를 수 있으니, 사용 중인 onnxruntime.quantization 문서를 확인하세요.

실무 팁:

  • 입력 분포가 넓고 outlier가 잦으면 Percentile 계열이 유리한 경우가 많음
  • 분류/검출처럼 logit 마진이 중요한 모델은 클리핑이 오히려 정확도를 깎을 수도 있어 A/B로 확인 필요

5) 연산자 커버리지: “INT8로 안 내려가는” 이유 찾기

양자화를 했는데 속도가 그대로면, 대부분 아래 중 하나입니다.

  1. EP가 해당 INT8 커널을 지원하지 않음
  2. 특정 노드가 FP32로 남아 그래프가 자주 디퀀타이즈됨
  3. 배치/형상 때문에 최적화가 안 걸림

ONNXRT 세션에서 실행 제공자 확인

import onnxruntime as ort

sess = ort.InferenceSession("model_int8_qdq.onnx", providers=["CPUExecutionProvider"])
print(sess.get_providers())

CPU에서라면 보통 CPUExecutionProvider로 충분하지만, 환경에 따라 빌드 옵션/가속 경로가 다를 수 있습니다.

프로파일링으로 병목 찾기

ONNXRT는 프로파일 기능을 제공합니다.

import onnxruntime as ort

so = ort.SessionOptions()
so.enable_profiling = True
sess = ort.InferenceSession("model_int8_qdq.onnx", sess_options=so, providers=["CPUExecutionProvider"])

# 더미 실행
import numpy as np
x = np.random.randn(1, 128, 768).astype(np.float32)
_ = sess.run(None, {"input": x})

profile_path = sess.end_profiling()
print("profile:", profile_path)

프로파일 JSON을 보면 어떤 노드가 시간을 먹는지, 어떤 커널로 실행되는지 단서가 나옵니다.

6) 정확도 튜닝: 민감 레이어 제외, per-channel 적용

정확도 손실이 크면 “무조건 더 많은 레이어를 INT8로”가 아니라, 민감한 구간은 FP16/FP32로 남기고 나머지를 INT8로 두는 전략이 효과적입니다.

대표적으로 민감한 후보:

  • LayerNorm 계열
  • Softmax 주변
  • 출력 head(특히 회귀)

ONNXRT 양자화에는 노드 제외(exclude) 옵션을 제공하는 경우가 많습니다(버전별 API 차이 존재). 개념적으로는 다음처럼 접근합니다.

  • 1차: 전체 INT8 시도
  • 2차: 정확도 깨지는 모델에서 특정 op 타입(LayerNorm/Softmax 등) 제외
  • 3차: 특정 노드 이름 단위로 제외(가장 민감한 블록만 FP로)

또한 weight는 per-channel quantization이 정확도에 유리한 경우가 많습니다(특히 Conv, Linear). 지원 여부에 따라 옵션을 켜서 비교하세요.

7) 성능 튜닝: Q/DQ 위치와 그래프 최적화

INT8 성능이 기대만큼 나오지 않는 흔한 패턴은 Q/DQ가 너무 자주 등장해 양자화-역양자화 오버헤드가 커지는 경우입니다.

개선 아이디어:

  • 가능한 한 큰 서브그래프가 INT8로 유지되도록(EP 커널 지원이 전제)
  • 동적 축을 최소화하거나, 서빙에서 실제로 쓰는 shape로 고정한 별도 엔진을 만들기
  • CPU에서는 thread 설정을 고정하고, 측정 시 warm-up을 충분히 수행

간단한 벤치마크 스켈레톤:

import time
import numpy as np
import onnxruntime as ort

def bench(model_path, iters=200, warmup=20):
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    x = np.random.randn(1, 128, 768).astype(np.float32)

    for _ in range(warmup):
        sess.run(None, {"input": x})

    t0 = time.perf_counter()
    for _ in range(iters):
        sess.run(None, {"input": x})
    t1 = time.perf_counter()

    return (t1 - t0) / iters

print("fp32 avg sec:", bench("model_fp32.onnx"))
print("int8 avg sec:", bench("model_int8_qdq.onnx"))

주의:

  • 측정은 반드시 동일 머신/동일 코어 상태에서 수행
  • OS 스케줄링 영향이 크면 분산이 커지므로 여러 번 반복 후 중앙값을 보세요

8) 흔한 오류/함정 체크리스트

1) 입력 dtype을 FP32로 넣어도 되나?

Q/DQ 모델은 보통 입력이 FP32여도 내부에서 QuantizeLinear로 INT8로 내려갑니다. 다만 “입력부터 INT8로 넣고 싶다”면 모델 입력 스펙을 바꿔야 하고, 전처리 단계에서 scale/zero-point를 맞춰야 해서 운영 복잡도가 올라갑니다. 대부분은 FP32 입력 유지가 실용적입니다.

2) 정확도 비교를 제대로 하고 있나?

  • 전처리 동일성
  • 랜덤 시드 고정
  • 배치/시퀀스 길이 분포 동일성
  • Top-1만 보지 말고 task에 맞는 지표(예: mAP, F1, perplexity 등)로 확인

3) 모델이 “INT8인데 느리다”

  • 실제로는 대부분 FP32로 실행되고 있을 수 있음(커버리지 확인)
  • 작은 배치에서는 오히려 오버헤드가 커져 손해일 수 있음
  • 특정 연산(GELU, LayerNorm 등)이 병목이면 INT8 이득이 제한적일 수 있음

4) 재현 가능한 실험 파이프라인

양자화 튜닝은 옵션이 많아 실험이 쉽게 산으로 갑니다. 실험 로그/아티팩트 캐시 전략은 CI에서도 중요합니다. 반복 실험이 많다면 GitHub Actions 캐시 미적중? 키 설계 7원칙처럼 캐시 키를 잘 설계해 시간을 줄이세요.

9) 실전 권장 워크플로우

  1. FP32 ONNX 내보내기(동적 축 최소화 가능하면 최소화)
  2. 캘리브레이션 세트 준비(실제 트래픽 대표)
  3. ONNXRT static Q/DQ로 전체 양자화
  4. 정확도 측정(태스크 지표)
  5. 성능 측정(프로파일링 포함)
  6. 정확도 깨지면
    • 캘리브레이션 방식 변경
    • per-channel 적용
    • 민감 레이어/op 제외
  7. 성능이 안 나오면
    • INT8 커버리지 확인
    • 병목 op 확인 후 구조/EP/shape 전략 재검토

마무리

PyTorch 2.x 환경에서 PTQ로 INT8을 “적용”하는 것과, ONNXRT에서 “제대로 빠르게” 만드는 것은 별개의 문제입니다. 캘리브레이션 데이터 품질, Q/DQ 커버리지, EP 커널 지원, shape 전략(동적 축), 그리고 민감 레이어 제외 같은 튜닝 포인트를 체계적으로 실험하면 정확도 손실을 통제하면서도 CPU 추론 비용을 유의미하게 낮출 수 있습니다.

다음 단계로는

  • 모델별로 민감 구간을 자동 탐색하는 ablation 스크립트 작성
  • 서빙에서 FP32/INT8를 카나리로 비교 배포
  • 입력 분포 드리프트에 대비한 주기적 재캘리브레이션

같은 운영 전략까지 묶어가면, 양자화가 “일회성 실험”이 아니라 “지속 가능한 최적화 루프”가 됩니다.