- Published on
PyTorch 모델을 ONNX+INT8로 경량화해 3배 빠르게
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 모델 추론이 병목이 되는 순간은 대개 두 가지입니다. 첫째, PyTorch 런타임 그대로 CPU에서 돌리면 연산자 최적화가 제한적이라 지연 시간이 늘어납니다. 둘째, FP32 가중치와 activation을 그대로 쓰면 메모리 대역폭과 캐시 미스가 크게 증가해 처리량이 떨어집니다. 이 글에서는 PyTorch -> ONNX -> INT8(PTQ) 파이프라인으로 모델을 경량화하고, CPU 기반 추론에서 속도를 3배 수준까지 끌어올리는 과정을 실전 관점에서 정리합니다.
핵심은 두 단계입니다.
- ONNX로 내보내서 실행 엔진(ONNX Runtime)의 그래프 최적화와 커널을 활용
- INT8 양자화로 연산량과 메모리 대역폭을 줄여 CPU에서 특히 큰 이득 확보
아래 내용은 이미지/텍스트 모델 모두에 적용 가능하지만, 예시는 가장 흔한 torchvision 분류 모델로 설명합니다.
목표 아키텍처: PyTorch 학습, ONNX Runtime 서빙
PyTorch는 학습에 강하고, ONNX Runtime은 추론 최적화에 강합니다. 특히 CPU에서는 ONNX Runtime이 제공하는 최적화(연산자 fusion, 상수 folding, layout 최적화 등)와 INT8 커널이 큰 차이를 만듭니다.
서빙 관점에서의 장점은 다음과 같습니다.
- Python 의존도를 줄이거나(선택) 추론 경로를 단순화
- 모델 교체를 파일 단위로 관리(
.onnx) - CPU 인스턴스에서 비용 대비 성능 개선
컨테이너 배포까지 고려한다면, 이미지 크기와 콜드스타트도 함께 줄이는 것이 좋습니다. 모델이 가벼워져도 런타임 이미지가 무거우면 체감이 줄 수 있으니, 배포 단계 최적화는 Docker 이미지 80% 줄이기 - distroless+SBOM+SLSA도 같이 참고하면 좋습니다.
준비물: 버전과 환경 체크
다음 조합이 무난합니다.
torch,torchvisiononnxonnxruntime(CPU)onnxruntime-tools또는onnxruntime.quantization
설치 예시는 아래처럼 진행합니다.
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install onnx onnxruntime
CPU만 다룬다면 위로 충분합니다. (GPU INT8은 TensorRT 등 다른 경로가 더 일반적입니다.)
1단계: PyTorch 모델을 ONNX로 Export
Export 시 가장 중요한 3가지
model.eval()필수: Dropout, BatchNorm 동작이 추론 모드로 고정dynamic_axes설정: 배치 크기 가변 처리- 입력/출력 이름 고정: 서빙 코드에서 안정적으로 바인딩
아래는 ResNet18을 ONNX로 내보내는 예시입니다.
import torch
import torchvision
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()
dummy = torch.randn(1, 3, 224, 224)
onnx_path = "resnet18_fp32.onnx"
torch.onnx.export(
model,
dummy,
onnx_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["logits"],
dynamic_axes={
"input": {0: "batch"},
"logits": {0: "batch"},
},
)
print("saved:", onnx_path)
Export가 깨질 때 흔한 원인
- 지원되지 않는 연산자(opset 문제):
opset_version을 올리거나 모델 구조 변경 필요 - 동적 shape가 복잡한 모델: 우선 고정 shape로 export 후 점진적으로 dynamic 적용
- 커스텀 연산: ONNX 변환 불가 케이스가 많아 대체 구현이 필요
이 단계에서 모델이 정상적으로 로드되고 inference 가능한지부터 확인해야 합니다.
2단계: ONNX Runtime로 FP32 성능 기준선 잡기
INT8로 가기 전에 FP32 기준선을 잡아야 “3배”가 진짜인지 판단할 수 있습니다. 또한 양자화 후 성능이 안 나올 때 원인을 분리할 수 있습니다.
import time
import numpy as np
import onnxruntime as ort
sess = ort.InferenceSession(
"resnet18_fp32.onnx",
providers=["CPUExecutionProvider"],
)
def bench(batch=1, warmup=10, iters=100):
x = np.random.randn(batch, 3, 224, 224).astype(np.float32)
# warmup
for _ in range(warmup):
_ = sess.run(["logits"], {"input": x})
t0 = time.perf_counter()
for _ in range(iters):
_ = sess.run(["logits"], {"input": x})
t1 = time.perf_counter()
avg_ms = (t1 - t0) * 1000 / iters
return avg_ms
print("fp32 avg ms:", bench(batch=1))
여기서 얻은 평균 지연 시간(ms)이 기준선입니다.
3단계: INT8 양자화(PTQ) 적용
INT8로 가는 방법은 크게 두 가지입니다.
- Dynamic Quantization: 가중치 중심으로 양자화, calibration 데이터 불필요, 주로
MatMul/Gemm계열에 효과 - Static Quantization: 가중치와 activation까지 양자화, calibration 데이터 필요, CNN 계열에서 더 큰 이득 가능
이미지 분류 같은 CNN 모델은 보통 Static이 더 큰 가속을 줍니다. 여기서는 실전에서 가장 많이 쓰는 ONNX Runtime의 PTQ(사후 양자화)를 사용합니다.
(A) Dynamic Quantization 예시
구현이 간단하고 실패 확률이 낮습니다. 다만 CNN에서는 기대보다 이득이 작을 수 있습니다.
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input="resnet18_fp32.onnx",
model_output="resnet18_int8_dynamic.onnx",
weight_type=QuantType.QInt8,
)
print("saved: resnet18_int8_dynamic.onnx")
(B) Static Quantization 예시 (권장)
Static은 calibration 데이터로 activation 범위를 추정합니다. 이 데이터는 정답 라벨이 필요 없고, 입력 분포만 비슷하면 됩니다.
import os
import numpy as np
from onnxruntime.quantization import (
quantize_static,
CalibrationDataReader,
QuantType,
QuantFormat,
)
class ImageNetLikeDataReader(CalibrationDataReader):
def __init__(self, n=200, batch=1):
self.batch = batch
self.data = [
{"input": np.random.randn(batch, 3, 224, 224).astype(np.float32)}
for _ in range(n)
]
self._iter = iter(self.data)
def get_next(self):
return next(self._iter, None)
def rewind(self):
self._iter = iter(self.data)
reader = ImageNetLikeDataReader(n=200, batch=1)
quantize_static(
model_input="resnet18_fp32.onnx",
model_output="resnet18_int8_static.onnx",
calibration_data_reader=reader,
quant_format=QuantFormat.QOperator,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QInt8,
)
print("saved: resnet18_int8_static.onnx")
실서비스라면 np.random 대신 실제 전처리된 샘플(예: 최근 1~2일 트래픽에서 추출한 입력)을 넣는 것이 정확도 방어에 유리합니다.
4단계: INT8 성능 측정과 3배 가속 검증
같은 벤치 함수를 INT8 모델에 적용합니다.
import time
import numpy as np
import onnxruntime as ort
def bench_model(path, batch=1, warmup=10, iters=100):
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
x = np.random.randn(batch, 3, 224, 224).astype(np.float32)
for _ in range(warmup):
_ = sess.run(["logits"], {"input": x})
t0 = time.perf_counter()
for _ in range(iters):
_ = sess.run(["logits"], {"input": x})
t1 = time.perf_counter()
return (t1 - t0) * 1000 / iters
fp32 = bench_model("resnet18_fp32.onnx")
int8 = bench_model("resnet18_int8_static.onnx")
print("fp32 ms:", fp32)
print("int8 ms:", int8)
print("speedup x:", fp32 / int8)
일반적으로 “3배”는 다음 조건에서 현실적입니다.
- CPU에서 동작 (특히 AVX2/VNNI 지원 서버)
- Conv 중심 모델에서 static quantization이 잘 먹힘
- 배치가 너무 크지 않음(지연 시간 최적화 기준)
반대로, 이미 GPU에서 FP16으로 잘 최적화된 경우라면 ONNX+INT8이 항상 이득을 주지는 않습니다.
정확도 하락을 통제하는 방법
INT8은 성능을 얻는 대신 정확도가 소폭 떨어질 수 있습니다. 이를 통제하는 실전 체크리스트는 아래와 같습니다.
1) Calibration 데이터 품질이 전부다
- 실제 입력 분포와 유사해야 함
- 전처리(리사이즈, 정규화, 채널 순서)가 서빙과 동일해야 함
- 너무 적으면 범위 추정이 불안정해 정확도 하락이 커질 수 있음
보통 수백~수천 샘플이면 시작하기 충분하고, 민감한 모델은 더 늘립니다.
2) 민감 레이어를 FP32로 남기는 전략
모든 노드를 INT8로 강제하면 손해인 경우가 있습니다. 예를 들어 첫 Conv, 마지막 FC, LayerNorm 주변은 민감할 수 있습니다. ONNX Runtime 양자화 옵션으로 제외 노드를 지정하는 방식이 실전에서 자주 쓰입니다.
환경에 따라 API가 조금씩 다를 수 있지만, 핵심은 “부분 양자화”입니다.
3) 모델 구조에 따라 Dynamic이 더 나을 때
Transformer 계열(특히 MatMul 비중이 큰 모델)은 dynamic quantization만으로도 큰 이득이 나기도 합니다. CNN에 static을 권장했지만, 본인 모델의 연산 분포를 보고 선택하는 것이 정석입니다.
ONNX Runtime 세션 옵션으로 추가 최적화
모델을 INT8로 만들었는데도 기대만큼 안 빨라지면, 세션 옵션과 스레딩이 병목일 수 있습니다.
import onnxruntime as ort
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = ort.InferenceSession(
"resnet18_int8_static.onnx",
sess_options=so,
providers=["CPUExecutionProvider"],
)
추가로 다음을 점검하세요.
- 컨테이너 CPU 제한(
cpu limit)이 너무 낮아 스레드가 못 뜨는지 - 동일 머신에서 다른 워크로드와 CPU 경쟁이 있는지
- 배치 크기와 동시성(멀티프로세스/멀티스레드) 조합이 적절한지
서버리스나 컨테이너 환경에서는 콜드스타트/오토스케일링이 체감 성능을 크게 좌우합니다. 추론 최적화와 별개로 플랫폼 지연이 문제라면 Cloud Run 503·콜드스타트 7분 지연 해결 가이드처럼 런타임/프로비저닝 튜닝도 함께 보세요.
실전 배포 팁: 재현 가능한 벤치와 롤백
성능 최적화는 “측정 가능한 형태”로 남겨야 팀이 안전하게 운영합니다.
- 동일 입력으로 FP32와 INT8의 출력 차이를 로그로 남기기(예: top-1 class, cosine similarity)
- p50/p95 지연 시간과 QPS를 함께 기록
- 모델 파일 버전 관리(해시) 및 즉시 롤백 가능하게 구성
특히 운영 환경에서 성능 이슈를 추적할 때는 “모델 자체 문제”와 “인프라 문제”를 분리해야 합니다. 쿠버네티스 기반이라면 네트워크/DNS 이슈가 지연 시간으로 보이는 경우도 있어, 장애 대응 관점에서는 EKS CoreDNS 장애? DNS 타임아웃 8단계 같은 체크리스트가 의외로 도움이 됩니다.
자주 겪는 문제와 해결
INT8로 바꿨는데 속도가 거의 안 오른다
- Dynamic만 적용했는데 Conv 비중이 큰 모델일 수 있음: static 시도
- CPU가 INT8 가속 명령어에 약함: 인스턴스 타입 확인(서버 CPU 세대)
- 배치가 너무 크거나 반대로 너무 작아 오버헤드가 지배: 배치/동시성 재조정
- 전처리 시간이 더 크다: 전처리까지 포함해 end-to-end로 측정
정확도가 눈에 띄게 떨어진다
- calibration 데이터가 실제 분포와 다름
- 전처리가 불일치
- 민감 레이어까지 전부 INT8로 내려감: 부분 양자화 고려
ONNX export 자체가 안 된다
- opset 올리기
- 지원 연산자로 치환
- 모델을 더 단순한 형태로 trace 가능하게 정리
정리: 3배 가속을 만드는 체크리스트
- FP32 ONNX 기준선 측정 후 INT8 비교
- CNN 계열은 static quantization 우선 검토
- calibration 데이터는 “실제 입력” 기반으로 구성
- 필요하면 민감 레이어는 FP32 유지
- 세션 옵션과 스레딩, 컨테이너 CPU 제한까지 함께 튜닝
이 과정을 제대로 밟으면, CPU 서빙에서 ONNX Runtime의 최적화와 INT8의 메모리/연산 이득이 합쳐져 “속도 3배”는 충분히 현실적인 목표가 됩니다. 모델별로 편차가 크니, 위 코드로 먼저 재현 가능한 벤치를 만든 뒤 점진적으로 적용 범위를 넓히는 방식이 가장 안전합니다.