- Published on
PyTorch에서 TensorRT INT8로 3배 가속하기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 환경에서 GPU가 병목이 되는 순간은 대개 두 가지입니다. 첫째, 모델이 커서 FP16으로도 지연시간이 기대만큼 줄지 않을 때. 둘째, 트래픽이 늘어 배치/동시성을 올리면 메모리 대역폭과 커널 런치 오버헤드가 급격히 커질 때입니다. 이때 가장 현실적인 카드가 TensorRT 최적화이고, 그중에서도 INT8 양자화는 잘만 적용하면 동일 GPU에서 2~4배 수준의 처리량 개선을 자주 만들어냅니다(모델 구조·입출력 크기·레이어 구성에 따라 편차 큼).
이 글은 PyTorch -> ONNX -> TensorRT 파이프라인으로 INT8 엔진을 만들고, 정확도 하락을 통제하면서 3배 가속에 근접시키는 데 필요한 핵심 포인트를 실전 관점에서 정리합니다.
왜 INT8이 빨라지나: 병목 관점에서 이해하기
INT8은 단순히 “정밀도를 낮추는 것”이 아니라, 다음 두 가지 효과를 동시에 노립니다.
메모리 대역폭 절감
- 같은 텐서라도
FP16대비INT8은 2배 더 작고,FP32대비 4배 더 작습니다. - 대형 feature map을 많이 다루는 CNN/비전 모델에서 효과가 큽니다.
- 같은 텐서라도
Tensor Core / INT8 경로 활용
- NVIDIA GPU는 특정 연산(특히 GEMM/Conv)에 대해 INT8 최적 경로를 제공합니다.
- TensorRT는 레이어 fusion, kernel auto-tuning, precision calibration을 통해 이 경로를 최대한 사용합니다.
다만 모든 레이어가 INT8로 떨어지는 건 아닙니다. 일부 레이어는 FP16/FP32로 남을 수 있고, 그 구간이 전체 지연시간을 잡아먹으면 기대한 만큼 빨라지지 않습니다. 그래서 “엔진 빌드 옵션 + 캘리브레이션 + 모델 구조”를 함께 봐야 합니다.
전체 파이프라인 개요
아래 순서가 가장 일반적입니다.
- PyTorch 모델을
eval()로 고정하고 ONNX로 export - ONNX 그래프를 정리(옵션:
onnxsim, shape 고정/동적 설정) - TensorRT에서
INT8 + (가능하면) FP16로 엔진 빌드 - 캘리브레이션 데이터로 스케일(quant scale) 산출
- 정확도(Top-1, mAP, BLEU 등)와 지연시간/처리량 측정
- 운영 배포(엔진 캐시, 드라이버/버전 고정, 모니터링)
이 글의 예시는 Python API를 기준으로 설명하되, 현장에서 자주 쓰는 trtexec도 함께 다룹니다.
1) PyTorch 모델을 ONNX로 내보내기
모델 export의 목표는 “TensorRT가 최적화하기 좋은 그래프”를 만드는 것입니다. 가장 중요한 건 다음입니다.
model.eval()및torch.no_grad()- 입력 shape를 명확히(정적 shape가 가장 쉬움)
- 불필요한 control flow 제거
- opset은 보통
13이상 권장(모델에 따라 상이)
import torch
def export_onnx(model, onnx_path="model.onnx", batch=1, h=224, w=224):
model.eval()
dummy = torch.randn(batch, 3, h, w, device="cpu")
with torch.no_grad():
torch.onnx.export(
model,
dummy,
onnx_path,
export_params=True,
opset_version=13,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch"},
"output": {0: "batch"},
},
)
# 사용 예
# export_onnx(my_model, "resnet.onnx")
ONNX 내보내기에서 흔한 함정
dynamic_axes를 너무 넓게 열면 TensorRT 최적화가 약해질 수 있습니다.- 반대로 완전 정적 shape로 고정하면 가장 빠르지만, 배치/해상도 변화가 있는 서비스에서는 불편합니다.
- 실무에서는
batch만 동적으로 열고, 이미지 해상도는 고정하는 구성이 타협점인 경우가 많습니다.
2) TensorRT INT8의 핵심: 캘리브레이션(PTQ)
대부분의 팀은 학습 단계에서 QAT(Quantization Aware Training)까지 가지 않고 PTQ(Post-Training Quantization) 로 먼저 성과를 냅니다. TensorRT INT8 PTQ의 본질은 다음입니다.
- 대표 입력(캘리브레이션 데이터)을 흘려보내서 activation 분포를 관찰
- 각 텐서(혹은 레이어)에 대해
scale을 정해INT8로 매핑
캘리브레이션 데이터는 “정확도”를 만드는 재료입니다.
- 데이터는 학습 데이터와 동일 분포가 이상적
- 보통 100~1,000 샘플로 시작(모델 민감도에 따라 더 필요)
- 전처리(정규화, 리사이즈, 색공간)가 서빙과 완전히 동일해야 함
전처리가 조금만 달라도 INT8에서는 정확도 하락이 크게 튈 수 있습니다.
3) trtexec로 빠르게 성능 확인하기
코드를 붙이기 전에, trtexec로 “이 모델이 INT8에서 이득이 나는지”를 먼저 확인하는 게 효율적입니다.
다음 예시는 ONNX를 입력으로 INT8 엔진을 만들고 벤치마크하는 기본 형태입니다.
trtexec \
--onnx=model.onnx \
--saveEngine=model_int8.engine \
--int8 \
--fp16 \
--workspace=4096 \
--warmUp=200 \
--iterations=1000 \
--verbose
--fp16을 같이 켜는 이유: 일부 레이어는 INT8로 못 내리거나, 혼합 정밀도가 더 빠른 경우가 많습니다.--workspace는 빌드 시 탐색 가능한 알고리즘 폭에 영향을 줍니다(크면 빌드 시간/메모리 증가).
캘리브레이션 데이터 지정
trtexec로도 캘리브레이션 캐시를 만들 수 있지만, 전처리를 정확히 맞추려면 Python calibrator를 직접 구현하는 편이 안전합니다.
4) Python으로 INT8 엔진 빌드(캘리브레이터 포함)
TensorRT Python API에서 INT8을 쓰려면 IInt8Calibrator를 구현해 배치를 공급하고, 캘리브레이션 캐시를 저장/재사용합니다.
아래 코드는 구조 이해를 위한 최소 예시입니다(실무에서는 에러 처리, 스트림, pinned memory 등을 추가).
import os
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
class ImageBatchCalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, batches, cache_file="calib.cache"):
super().__init__()
self.batches = batches # list of numpy arrays, shape: (N, C, H, W)
self.cache_file = cache_file
self.index = 0
# 단일 배치 크기로 가정
self.device_input = cuda.mem_alloc(self.batches[0].nbytes)
def get_batch_size(self):
return self.batches[0].shape[0]
def get_batch(self, names):
if self.index >= len(self.batches):
return None
batch = self.batches[self.index]
cuda.memcpy_htod(self.device_input, batch)
self.index += 1
return [int(self.device_input)]
def read_calibration_cache(self):
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
return f.read()
return None
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)
def build_int8_engine(onnx_path, engine_path, calibrator, workspace_mb=4096):
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(parser.get_error(i))
raise RuntimeError("ONNX parse failed")
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024)
# 혼합 정밀도 권장
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = calibrator
# 필요 시: strict 타입 강제는 오히려 느려지거나 빌드 실패를 유발할 수 있음
# config.set_flag(trt.BuilderFlag.STRICT_TYPES)
serialized = builder.build_serialized_network(network, config)
if serialized is None:
raise RuntimeError("Engine build failed")
with open(engine_path, "wb") as f:
f.write(serialized)
# calibrator에 넣을 batches는 반드시 실제 전처리 결과여야 함
# batches = [np.random.randn(8,3,224,224).astype(np.float32) for _ in range(50)]
# calib = ImageBatchCalibrator(batches)
# build_int8_engine("model.onnx", "model_int8.engine", calib)
캘리브레이션 전처리 체크리스트
- 입력 dtype은 대개
float32로 주고 TensorRT가 스케일을 학습합니다(모델 입력이 이미 normalize된 float라면 그대로). - 채널 순서
NCHW/NHWC불일치가 가장 흔한 실수입니다. - 이미지의
mean/std,RGB/BGR, 리사이즈 방식(bilinear, area), crop 규칙을 서빙과 동일하게 유지하세요.
5) “3배 가속”을 현실로 만드는 측정 방법
가속을 주장하려면 측정이 일관돼야 합니다. 최소한 아래는 지키는 게 좋습니다.
- 워밍업: 최소 수백 회(커널 캐시/클럭 안정화)
- 측정 반복: 1,000회 이상 권장
- 동시성: 실제 서비스와 유사하게(단일 요청만 빠른 건 의미가 약함)
- 배치: 서비스의 typical batch를 반영
TensorRT 엔진은 일반적으로 다음 순서로 성능이 나옵니다.
FP32 (PyTorch eager)가장 느림FP16 (TensorRT)빠름INT8 (TensorRT)가장 빠름(모델에 따라 FP16과 큰 차이 없을 수도)
특히 Transformer 계열은 matmul 최적화가 잘 되면 INT8 이득이 크지만, attention 패턴/플러그인/shape에 따라 편차가 큽니다.
6) 정확도 하락을 줄이는 실전 팁
INT8에서 가장 무서운 건 “평균 정확도는 괜찮은데 특정 케이스에서만 크게 깨지는” 현상입니다. 아래는 현장에서 효과가 좋았던 체크 포인트입니다.
6.1 캘리브레이션 데이터는 ‘대표성’이 전부
- 클래스 불균형이 큰 분류 문제라면, 소수 클래스 샘플을 일부러 포함하세요.
- 입력 길이/해상도 분포가 넓다면, 경계값(짧은 것, 긴 것)을 반드시 넣습니다.
6.2 민감 레이어는 FP16으로 남기기
TensorRT는 레이어별로 precision을 선택할 수 있습니다. 특정 레이어(예: 첫 conv, 마지막 linear, softmax 주변)가 민감하면 FP16으로 남기는 게 정확도-성능 균형에 도움이 됩니다.
다만 레이어 강제는 그래프/버전에 따라 API가 복잡해질 수 있어, 우선은 캘리브레이션 개선으로 해결을 시도하고, 마지막에 “부분 FP16 유지”를 검토하는 순서를 추천합니다.
6.3 동적 shape는 캘리브레이션 범위를 넓혀라
동적 배치/시퀀스 길이를 허용하면, INT8 스케일이 특정 길이에 과적합될 수 있습니다. 캘리브레이션 배치에 다양한 shape를 포함하거나, 운영에서 사용하는 최빈 shape 위주로 엔진을 여러 개 만들어 라우팅하는 방식도 고려할 만합니다.
7) 배포/운영에서 자주 터지는 문제
INT8 엔진은 “빌드만 되면 끝”이 아니라, 운영에서 다음 이슈가 자주 발생합니다.
- 드라이버/CUDA/TensorRT 버전 불일치로 엔진 로드 실패
- 컨테이너 재기동 시 엔진 재빌드로 인해 스타트업 지연
- GPU 메모리 부족(OOM)로 프로세스 크래시
특히 k8s에서 GPU 워크로드가 OOMKilled로 떨어지면 원인 파악이 번거롭습니다. 비슷한 유형의 장애 분석 흐름은 EKS CrashLoopBackOff - OOMKilled·Exit 137 원인과 해결 글의 체크리스트가 그대로 도움이 됩니다(메모리 리밋, 프로브, 노드 리소스, 로그 수집).
또한 엔진 빌드/로딩 이후 요청 타임아웃이 연쇄로 번지면 gRPC 기반 서비스에서는 장애가 증폭되기 쉽습니다. 타임아웃 전파를 설계적으로 막는 관점은 gRPC MSA에서 DEADLINE_EXCEEDED 연쇄 장애 차단도 함께 참고할 만합니다.
8) “정말 3배”가 나오는 조건
경험적으로 3배 수준이 잘 나오는 케이스는 다음 조합이 많습니다.
- 입력 텐서가 크고(예:
224x224이상), conv/GEMM 비중이 높음 - FP32 PyTorch eager를 기준선으로 잡았을 때
- TensorRT에서 fusion이 잘 일어나고, INT8로 내려가는 레이어 비율이 높음
- 캘리브레이션이 안정적이라 재시도 없이 최적 엔진이 한 번에 만들어짐
반대로 다음이면 3배가 어렵습니다.
- 이미 FP16 TensorRT로 충분히 최적화되어 있고, 병목이 다른 곳(전처리/후처리/IO)에 있음
- 작은 모델(커널 런치 오버헤드가 지배적)
- INT8로 못 내리는 연산이 많아 혼합 정밀도 구간이 길어짐
즉, INT8은 만능이 아니라 “GPU 연산이 병목인 모델”에서 가장 강력합니다. 전처리/후처리가 병목이면, 엔진만 바꿔서는 체감이 작습니다.
9) 최소 검증 플로우(추천)
실무에서 시간을 아끼는 검증 순서는 아래가 안전합니다.
- PyTorch FP32 기준 지연시간/처리량 측정
- TensorRT FP16 엔진으로 측정
- TensorRT INT8 엔진으로 측정
- 정확도 비교(샘플링이 아니라, 가능하면 검증 셋 전체)
- 운영 부하(동시성, 배치)에서 다시 측정
이 과정을 거치면 “INT8이 실제로 이득인지”와 “정확도 하락이 허용 가능한지”를 빠르게 결론낼 수 있습니다.
마무리
PyTorch -> TensorRT INT8는 비용 대비 효과가 큰 최적화 옵션이지만, 성패는 대부분 캘리브레이션 데이터 품질과 측정/운영 조건의 일치에서 갈립니다. 먼저 trtexec로 가능성을 확인하고, Python calibrator로 전처리를 엄격히 맞춘 뒤, FP16/INT8 혼합 정밀도로 엔진을 빌드해 보세요. 이 흐름만 제대로 잡아도 “GPU 한 대로 3배 처리”에 근접하는 케이스를 충분히 만들 수 있습니다.
원하시면 모델 종류(CNN, Transformer, detection 등)와 입력 shape, 목표 TPS/latency, GPU 기종을 알려주면 그 조건에 맞춰 캘리브레이션 샘플 수, 엔진 빌드 옵션, 병목 분해(전처리 포함)까지 더 구체적으로 튜닝 가이드를 이어서 정리해드릴 수 있습니다.