Published on

PyTorch INT8 PTQ로 추론 4배 가속하는 법

Authors

서빙 환경에서 모델 추론이 병목일 때, 가장 빠르게 체감 성능을 올리는 방법 중 하나가 INT8 양자화입니다. 특히 학습을 다시 하지 않고도 적용 가능한 PTQ(Post-Training Quantization)는 “지금 있는 FP32 모델을 최대한 유지한 채” 추론 비용을 낮추는 데 적합합니다. 잘 맞는 모델과 하드웨어 조합에서는 CPU 기준으로 지연(latency)과 처리량(throughput) 모두에서 2배에서 4배 수준의 개선을 얻는 사례가 흔합니다.

이 글은 PyTorch에서 INT8 PTQ를 적용하는 실전 흐름을 중심으로, 정확도 하락을 줄이는 캘리브레이션 방법, 흔한 함정(레이어 미지원, 동적/정적 양자화 선택, 성능이 안 나오는 이유)까지 정리합니다.

INT8 PTQ가 빠른 이유: 메모리 대역폭과 벡터화

FP32는 가중치와 활성값이 4바이트입니다. INT8은 1바이트라서 같은 연산을 하더라도:

  • 메모리 읽기/쓰기 트래픽이 대폭 감소
  • CPU의 INT8 벡터 명령어(예: AVX2/VNNI 등) 경로를 활용 가능
  • 특히 GEMM(행렬곱) 비중이 큰 모델에서 효과가 큼

다만 “항상 4배”는 아닙니다. 실제 가속은 아래에 크게 좌우됩니다.

  • CPU가 INT8 가속 경로를 얼마나 잘 지원하는지
  • 모델이 양자화 친화적인 연산(Conv, Linear) 위주인지
  • 전처리/후처리, softmax, layernorm, embedding 등 FP32로 남는 구간이 얼마나 되는지

PTQ 종류: 동적 양자화 vs 정적 양자화

PyTorch PTQ는 크게 두 갈래가 있습니다.

동적 양자화(Dynamic Quantization)

  • 주로 nn.Linear 중심(Transformer의 FFN/Projection 등)에 적용
  • 활성값(activation)은 런타임에 스케일을 추정해 양자화
  • 캘리브레이션 데이터가 없어도 적용 가능
  • 정확도 손실이 상대적으로 작고 적용이 간단
  • CPU에서 효과가 좋고, 특히 NLP 계열에서 많이 씀

정적 양자화(Static/PTQ with Calibration)

  • Conv/Linear 모두 폭넓게 가능
  • 활성값 스케일을 캘리브레이션 데이터로 미리 추정
  • 준비 과정(관측자 삽입, 캘리브레이션 실행, 변환)이 필요
  • 성능/정확도 모두 잘 나오는 케이스가 많지만 세팅 난이도는 더 높음

정리하면, Transformer 계열이면 동적 양자화부터 시작하고, CNN/비전 계열이면 정적 양자화를 우선 고려하는 접근이 현실적입니다.

적용 전 체크리스트

PTQ를 “변환만 했는데 빨라지지 않거나 오히려 느려지는” 상황을 피하려면 아래를 먼저 확인하세요.

  • 벤치마크는 반드시 torch.inference_mode() 로 측정
  • CPU 스레드 수/바인딩 설정(예: torch.set_num_threads)이 일관적인지
  • 모델이 eval() 모드인지
  • 입력 배치 크기와 실제 서빙 패턴이 유사한지
  • 성능 병목이 모델 연산인지(전처리, 토크나이즈, I/O가 병목이면 양자화 효과가 제한됨)

프로덕션에서 CPU 사용량이 치솟아 Pod가 불안정해지는 경우도 흔합니다. 그런 경우에는 양자화 자체뿐 아니라 리소스 제한과 OOM/재시작 패턴도 함께 보세요. 관련해서는 K8s CrashLoopBackOff 원인별 진단 체크리스트EKS Pod OOMKilled 반복 원인과 메모리·GC·Limit 튜닝이 같이 도움이 됩니다.

실습 1: Linear 중심 모델에 동적 INT8 양자화 적용

아래 예시는 nn.Linear가 많은 모델(간단한 MLP, Transformer의 일부 블록 등)에 동적 양자화를 적용하는 최소 코드입니다.

import time
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim=1024, hidden=4096, out_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)


def bench(model, x, iters=200, warmup=50):
    model.eval()
    with torch.inference_mode():
        for _ in range(warmup):
            _ = model(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            _ = model(x)
        t1 = time.perf_counter()
    return (t1 - t0) / iters


torch.set_num_threads(8)
model_fp32 = MLP().eval()

# 동적 양자화: Linear에 대해 INT8 weight + 런타임 activation quant 적용
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,
    {nn.Linear},
    dtype=torch.qint8,
)

x = torch.randn(32, 1024)

fp32_t = bench(model_fp32, x)
int8_t = bench(model_int8, x)

print(f"FP32: {fp32_t*1000:.3f} ms")
print(f"INT8: {int8_t*1000:.3f} ms")
print(f"Speedup: {fp32_t/int8_t:.2f}x")

동적 양자화에서 자주 하는 실수

  • dtypetorch.qint8로 주지 않거나, 양자화 대상 레이어 set을 잘못 지정
  • 실제 병목이 Linear가 아닌데 Linear만 양자화해서 효과가 제한됨
  • 배치가 너무 작아 오버헤드가 상대적으로 커짐

실습 2: Conv 모델에 정적 INT8 PTQ 적용(캘리브레이션 포함)

정적 양자화는 “관측자(observer)를 삽입하고 캘리브레이션 데이터로 통계를 모은 뒤 변환(convert)”하는 흐름입니다.

아래는 예제로 작은 ConvNet을 대상으로 한 전형적인 파이프라인입니다. 실제 서비스에서는 캘리브레이션 데이터셋을 “실제 트래픽을 대표하는 샘플”로 구성하는 것이 정확도에 매우 중요합니다.

import torch
import torch.nn as nn

class SmallConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def calibrate(model, data_loader, num_batches=20):
    model.eval()
    with torch.inference_mode():
        for i, (x, _) in enumerate(data_loader):
            _ = model(x)
            if i + 1 >= num_batches:
                break


# 1) FP32 모델 준비
model = SmallConv().eval()

# 2) 정적 양자화 준비: qconfig 설정
# backend는 환경에 따라 fbgemm(서버 x86) 또는 qnnpack(모바일/ARM) 사용
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

# 3) fuse: Conv+ReLU 같이 fuse 가능한 패턴을 합쳐 성능/정확도에 유리
# 이 예제는 모듈이 분리되어 있으므로 fuse 가능한 리스트를 지정
fused = torch.ao.quantization.fuse_modules(model, [["conv", "relu"]], inplace=False)

# 4) prepare: observer 삽입
prepared = torch.ao.quantization.prepare(fused, inplace=False)

# 5) calibration: 대표 데이터로 통계 수집
# data_loader는 (image, label) 배치를 반환한다고 가정
# calibrate(prepared, data_loader)

# 여기서는 예시로 랜덤 텐서로 캘리브레이션을 흉내냄
prepared.eval()
with torch.inference_mode():
    for _ in range(20):
        _ = prepared(torch.randn(32, 3, 224, 224))

# 6) convert: INT8 모델로 변환
quantized = torch.ao.quantization.convert(prepared, inplace=False)

# 7) 추론 테스트
x = torch.randn(8, 3, 224, 224)
with torch.inference_mode():
    y = quantized(x)
print(y.shape)

정적 양자화에서 성능이 안 나오는 대표 원인

  • fuse를 안 해서 Conv와 활성함수가 분리 실행됨
  • 캘리브레이션 데이터가 실제 분포를 대표하지 못해 스케일이 망가짐
  • 양자화가 적용되지 않는 연산이 많아 INT8 구간 비율이 낮음

정확도 하락을 줄이는 캘리브레이션 전략

PTQ의 성패는 “스케일과 제로포인트 추정이 얼마나 잘 되었는가”로 갈립니다. 아래는 실무에서 바로 쓰는 팁입니다.

1) 캘리브레이션 데이터는 적어도 수백에서 수천 샘플

정적 양자화는 20배치 같은 장난감 캘리브레이션으로는 실제 정확도가 흔들릴 수 있습니다. 클래스 분포, 길이 분포(텍스트), 밝기/노이즈(이미지) 등 서빙 환경을 반영하세요.

2) outlier가 많은 분포는 퍼센타일/히스토그램 기반이 유리

기본 observer는 min/max 기반이 많아 outlier에 약할 수 있습니다. 모델과 PyTorch 버전에 따라 히스토그램 기반 observer를 선택하는 방식도 고려합니다.

3) 레이어별로 “양자화 제외”를 전략적으로 적용

LayerNorm, Softmax, 일부 Embedding, 특정 커스텀 연산은 INT8로 바꾸면 정확도 손실이 커질 수 있습니다. 전체를 무리하게 INT8로 만들기보다, 병목이 큰 Linear/Conv를 우선 양자화하고 민감한 레이어는 FP32로 남기는 것이 총합 성능과 품질의 균형이 좋습니다.

벤치마크: 4배 가속을 재현하려면 무엇을 맞춰야 하나

“4배 가속”은 대개 아래 조건이 맞을 때 가능합니다.

  • CPU가 INT8 가속 명령어를 지원하고, PyTorch가 해당 backend를 잘 활용
  • 모델의 대부분 시간이 Linear/Conv GEMM에 쓰임
  • 배치 크기가 너무 작지 않음(오버헤드 대비 연산량이 충분)
  • 스레드 수가 적절하고, NUMA/코어 바인딩이 안정적

반대로, 토크나이저/전처리 비중이 크거나, attention/softmax 같은 FP32 구간이 지배적이면 체감 가속은 제한됩니다. 이때는 양자화 외에 런타임 자체를 바꾸는 접근(예: ONNX Runtime)도 같이 검토할 수 있습니다. LLM 쪽 양자화/런타임 최적화는 ONNX Runtime로 LLM INT4 양자화와 지연 개선도 참고할 만합니다.

운영 관점: 모델만 빠르게 해서는 끝이 아니다

INT8로 CPU 사용량과 지연이 줄면, 다음 단계로는 “같은 인프라에서 더 많은 요청을 처리”하게 됩니다. 여기서 새로운 병목이 생길 수 있습니다.

  • 컨테이너 메모리 제한이 낮아져도 안정적인지(캐시, 배치, 워커 수 변화)
  • HPA가 CPU 기반이면 스케일 정책이 바뀌어야 하는지
  • 로깅/메트릭/트레이싱이 새 병목이 되는지

특히 K8s에서 워커 수를 늘리다 CrashLoopBackOff나 OOMKilled가 반복되면, 성능 튜닝만으로는 해결이 안 됩니다. 위에서 언급한 운영 트러블슈팅 글들을 같이 보는 것이 안전합니다.

결론: 가장 현실적인 적용 순서

실무에서 실패 확률을 낮추는 순서는 보통 이렇게 갑니다.

  1. CPU 추론 병목을 측정하고, 전처리/후처리 비중을 분리
  2. Transformer 계열은 동적 양자화로 빠르게 효과 확인
  3. CNN 계열은 fuse 포함 정적 양자화 + 캘리브레이션 데이터 품질 확보
  4. 정확도 하락이 크면 민감 레이어를 FP32로 남기고 병목 레이어 위주로 재조정
  5. 그래도 부족하면 ONNX Runtime 등 다른 런타임/그래프 최적화까지 확장

PTQ는 “학습 없이도” 적용 가능한 대신, 캘리브레이션과 벤치마크가 품질을 좌우합니다. 위 코드 템플릿을 기반으로, 본인 모델 구조(Linear/Conv 비중)와 실제 트래픽 분포를 반영해 세팅을 다듬으면 INT8만으로도 충분히 4배에 가까운 가속을 노려볼 수 있습니다.