Published on

파이썬 ONNX Runtime로 BERT 3배 경량·고속화

Authors

서빙 환경에서 BERT류 모델을 그대로 PyTorch로 돌리면, 정확도는 만족스럽지만 지연 시간과 비용이 빠르게 병목이 됩니다. 특히 CPU 기반 배치 추론이나 API 형태의 단건 추론에서는 torch 런타임 오버헤드, 불필요한 연산 그래프, FP32 가중치로 인한 메모리 대역폭 한계가 누적됩니다.

이 글에서는 파이썬에서 ONNX Runtime을 사용해 BERT를 3배 수준으로 경량·고속화하는 것을 목표로, 아래 순서로 “재현 가능한” 튜닝 절차를 정리합니다.

  • PyTorch 또는 Hugging Face 모델을 ONNX로 내보내기
  • ONNX Runtime 그래프 최적화 설정
  • 동적 양자화로 모델 크기와 CPU 지연 시간 줄이기
  • I/O 바인딩과 세션 옵션으로 런타임 오버헤드 줄이기
  • 벤치마크 방법과 정확도 회귀 체크
  • 운영에서 자주 터지는 환경 이슈 체크리스트

참고: 환경이 꼬여 pip install 은 성공했는데 실행 시 ModuleNotFoundError 가 뜨는 경우가 생각보다 흔합니다. 가상환경 혼용을 10분 안에 정리하는 체크리스트는 pip install은 성공인데 실행하면 ModuleNotFoundError가 뜰 때 venv poetry conda 혼용으로 꼬인 인터프리터와 site-packages를 10분 만에 진단하고 확실히 고치는 체크리스트 를 먼저 확인해두면 삽질을 크게 줄일 수 있습니다.

목표 성능을 현실적으로 잡기

“3배”는 과장이 아니라, 조건이 맞으면 충분히 가능합니다.

  • CPU 추론: FP32 PyTorch 대비 ONNX Runtime GraphOptimizationLevel.ORT_ENABLE_ALL + 동적 양자화(QInt8) 조합이 가장 흔한 승리 플랜입니다.
  • 모델 크기: BERT base 계열은 FP32에서 대략 400MB 전후(가중치 기준)인데, INT8 양자화로 1/3 수준까지 줄어드는 경우가 많습니다.

다만 아래 조건에 따라 편차가 큽니다.

  • 입력 시퀀스 길이(예: 128 vs 512)
  • 배치 크기(단건 API vs 배치 추론)
  • CPU 종류와 스레딩(AVX2, AVX512, 코어 수)
  • 토크나이저가 병목인지 여부(대부분 토크나이저도 꽤 큼)

이 글에서는 BERT 계열의 분류 모델을 예로 들지만, 토큰 분류나 문장 임베딩도 같은 원리로 적용됩니다.

준비물: 버전과 설치

권장 조합(예시):

  • Python 3.10+
  • transformers
  • torch
  • onnx
  • onnxruntime 또는 onnxruntime-gpu
  • (선택) optimum 또는 onnxruntime-tools

설치 예시:

pip install -U transformers torch onnx onnxruntime
# 성능 측정
pip install -U numpy psutil

운영 서버에서 CPU 성능이 들쑥날쑥하면, 컨테이너 리소스 제한과 노드 상태도 같이 봐야 합니다. 쿠버네티스에서 갑자기 지연이 튀고 재시작이 반복된다면, 앱 문제만 보지 말고 K8s CrashLoopBackOff 10분 원인별 진단법 처럼 인프라 신호도 먼저 확인하는 게 빠릅니다.

1) PyTorch BERT를 ONNX로 내보내기

핵심은 동적 축을 제대로 지정해 입력 길이와 배치가 바뀌어도 재사용 가능한 ONNX를 만드는 것입니다.

아래 예시는 AutoModelForSequenceClassification 기준입니다.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "bert-base-uncased"
num_labels = 2

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=num_labels)
model.eval()

text = "onnx runtime optimization for bert"
inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)

onnx_path = "bert_cls.onnx"

dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},
    "logits": {0: "batch"},
}

with torch.no_grad():
    torch.onnx.export(
        model,
        (inputs["input_ids"], inputs["attention_mask"]),
        onnx_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=17,
        do_constant_folding=True,
    )

print("exported:", onnx_path)

내보내기에서 자주 터지는 포인트

  • opset_version 가 너무 낮으면 연산이 깨질 수 있습니다. 최근 모델은 17 정도가 무난합니다.
  • token_type_ids 가 필요한 모델도 있습니다. 그 경우 입력에 추가하고 dynamic_axes 도 같이 지정하세요.
  • max_length 를 너무 크게 잡으면 벤치마크가 비현실적으로 느려집니다. 실제 트래픽 분포에 맞추는 게 중요합니다.

2) ONNX Runtime 세션 설정으로 기본 성능 확보

ONNX Runtime은 세션 옵션과 그래프 최적화 레벨만으로도 꽤 차이가 납니다.

import onnxruntime as ort

sess_options = ort.SessionOptions()
# 가장 강한 그래프 최적화
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# 스레드 튜닝: 환경에 맞게 조절
sess_options.intra_op_num_threads = 4
sess_options.inter_op_num_threads = 1

providers = ["CPUExecutionProvider"]

session = ort.InferenceSession("bert_cls.onnx", sess_options=sess_options, providers=providers)
print(session.get_providers())

스레드 튜닝 팁

  • 단건 요청이 많으면 intra_op_num_threads 를 코어 수만큼 무작정 올리는 게 오히려 역효과일 수 있습니다.
  • 배치 추론이나 워커 수가 적으면 스레드를 늘리는 편이 유리합니다.
  • 컨테이너 CPU limit이 2 인데 스레드를 16 으로 두면 문맥 전환 비용만 커집니다.

3) 동적 양자화로 3배 경량화의 핵심 만들기

CPU에서 BERT를 크게 빠르게 만드는 가장 흔한 방법은 동적 양자화입니다.

  • 가중치를 INT8로 줄이고
  • 매트멀 계열 연산에서 메모리 대역폭과 캐시 효율을 개선합니다.

ONNX Runtime의 양자화 도구를 사용합니다.

pip install -U onnxruntime-tools

양자화 코드:

from onnxruntime.quantization import quantize_dynamic, QuantType

fp32_path = "bert_cls.onnx"
int8_path = "bert_cls.int8.onnx"

quantize_dynamic(
    model_input=fp32_path,
    model_output=int8_path,
    weight_type=QuantType.QInt8,
)

print("quantized:", int8_path)

양자화 후 체크할 것

  1. 모델 크기: 파일 크기가 유의미하게 줄어야 합니다.
  2. 정확도: 분류/랭킹은 대체로 영향이 작지만, 태스크에 따라 민감할 수 있습니다.
  3. 성능: CPU 종류에 따라 이득이 다릅니다. AVX512 VNNI 같은 지원이 있으면 더 유리합니다.

4) 추론 코드: 토크나이저 포함 오버헤드 분리

추론 성능을 제대로 측정하려면 “모델”과 “토크나이저” 시간을 분리해야 합니다. 토크나이저가 병목이면 모델만 빨라져도 체감이 적습니다.

import time
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def bench(session, text_list, max_length=128, warmup=10, iters=100):
    # tokenize
    t0 = time.perf_counter()
    enc = tokenizer(
        text_list,
        return_tensors="np",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    t1 = time.perf_counter()

    ort_inputs = {
        "input_ids": enc["input_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.int64),
    }

    # warmup
    for _ in range(warmup):
        _ = session.run(["logits"], ort_inputs)

    t2 = time.perf_counter()
    for _ in range(iters):
        _ = session.run(["logits"], ort_inputs)
    t3 = time.perf_counter()

    return {
        "tokenize_ms": (t1 - t0) * 1000,
        "warmup_ms": (t2 - t1) * 1000,
        "infer_avg_ms": ((t3 - t2) / iters) * 1000,
    }

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

fp32_sess = ort.InferenceSession("bert_cls.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"])
int8_sess = ort.InferenceSession("bert_cls.int8.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"])

texts = ["this is a test sentence"] * 8  # 배치 8 예시

print("fp32:", bench(fp32_sess, texts))
print("int8:", bench(int8_sess, texts))

여기서 관찰하고 싶은 것은 infer_avg_ms 입니다. 토크나이즈 시간이 더 크면, 다음 중 하나를 고려해야 합니다.

  • tokenizers 의 fast 토크나이저 사용 여부
  • 토크나이즈를 별도 서비스로 분리하거나 캐시
  • 입력 길이 제한 정책

5) I/O Binding으로 메모리 복사 줄이기(고급)

단건 요청이 많고 지연 시간이 민감하면, 작은 오버헤드도 누적됩니다. ONNX Runtime은 IOBinding 으로 입력/출력 버퍼를 바인딩해 불필요한 복사를 줄일 수 있습니다.

CPU에서도 효과가 있을 수 있지만, 특히 GPU에서 더 체감이 큽니다. 아래는 CPU에서의 예시 흐름입니다.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("bert_cls.int8.onnx", providers=["CPUExecutionProvider"])

# 더미 입력
batch, seq = 8, 128
input_ids = np.zeros((batch, seq), dtype=np.int64)
attention_mask = np.ones((batch, seq), dtype=np.int64)

io = session.io_binding()
io.bind_cpu_input("input_ids", input_ids)
io.bind_cpu_input("attention_mask", attention_mask)

# 출력은 이름만 바인딩하고 런타임이 버퍼를 채우게 할 수도 있음
io.bind_output("logits")

session.run_with_iobinding(io)
outputs = io.copy_outputs_to_cpu()
logits = outputs[0]
print(logits.shape)

서비스가 고QPS로 올라가면, 이런 미세 최적화가 tail latency를 줄이는 데 도움이 됩니다.

6) 정확도 회귀 테스트: 최소한의 안전장치

양자화나 그래프 최적화는 대개 안전하지만, 운영에서 가장 위험한 건 “조용히 성능이 좋아졌는데 결과가 달라진” 상황입니다.

간단한 회귀 체크를 권장합니다.

  • 샘플 데이터 N 개를 뽑아 PyTorch 출력과 ONNX 출력의
    • argmax 일치율
    • 로짓의 평균 절대 오차
    • AUC, F1 같은 태스크 지표 를 비교합니다.

예시 코드(로짓 비교):

import numpy as np
import torch
import onnxruntime as ort
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)
pt_model = AutoModelForSequenceClassification.from_pretrained(model_id)
pt_model.eval()

ort_sess = ort.InferenceSession("bert_cls.int8.onnx", providers=["CPUExecutionProvider"])

texts = [
    "i love this product",
    "this is terrible",
    "it is okay",
]
enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)

with torch.no_grad():
    pt_logits = pt_model(enc["input_ids"], enc["attention_mask"]).logits.cpu().numpy()

ort_inputs = {
    "input_ids": enc["input_ids"].cpu().numpy().astype(np.int64),
    "attention_mask": enc["attention_mask"].cpu().numpy().astype(np.int64),
}
ort_logits = ort_sess.run(["logits"], ort_inputs)[0]

mae = np.mean(np.abs(pt_logits - ort_logits))
pt_pred = pt_logits.argmax(axis=1)
ort_pred = ort_logits.argmax(axis=1)
acc = (pt_pred == ort_pred).mean()

print("mae:", float(mae))
print("argmax_match:", float(acc))

운영 기준으로는 argmax_match 가 충분히 높고, 오프라인 지표가 허용 범위 내인지 확인하면 됩니다.

7) 운영 튜닝 체크리스트

7.1 컨테이너 환경에서 성능이 갑자기 떨어질 때

  • CPU throttling 여부 확인
  • 워커 수와 intra_op_num_threads 충돌 확인
  • 노드 간 성능 편차(서로 다른 CPU 세대 혼재) 확인

오토스케일링 환경에서는 노드가 늘지 않아 트래픽이 몰리면 “모델이 느려졌다”로 오해하기 쉽습니다. EKS 환경이라면 Karpenter 도입 후 EKS 노드가 안 늘 때 해결법 도 같이 참고하면 원인 분리가 빨라집니다.

7.2 모델 파일 로딩과 콜드 스타트

  • ONNX 세션 생성은 비용이 큽니다. 프로세스 시작 시 1회 생성 후 재사용하세요.
  • 서버리스나 스케일 아웃이 잦으면 콜드 스타트가 지연의 대부분일 수 있습니다.
  • 모델 파일이 크면 디스크 I/O도 병목이 됩니다. INT8로 줄이는 것만으로도 콜드 스타트가 개선됩니다.

7.3 입력 정책이 곧 비용 정책

  • max_length 를 무작정 512 로 두면 대부분의 요청에서 낭비가 발생합니다.
  • 실제 분포를 보고 128 또는 256 같은 상한을 두고, 긴 문서는 요약/분할 같은 전처리를 고려하세요.

8) 기대 가능한 결과 예시(현실적인 해석)

일반적으로 다음 조합이 가장 비용 대비 효과가 좋습니다.

  • ONNX 변환 + ORT_ENABLE_ALL
  • 동적 양자화(INT8)
  • 스레드 튜닝(컨테이너 CPU limit에 맞춤)

이 조합으로 모델 파일 크기 3분의 1 수준, CPU 추론 지연 2배에서 4배 개선은 흔히 관측됩니다. 다만 토크나이저 시간이 큰 워크로드에서는 end-to-end 개선폭이 줄어들 수 있으니, 반드시 “모델 추론”과 “전체 요청”을 분리해 측정하세요.

마무리: 가장 중요한 건 측정 가능한 파이프라인

BERT 최적화는 기법 자체보다도, 다음을 갖추는 순간부터 안정적으로 빨라집니다.

  • 입력 길이/배치/동시성 조건이 고정된 벤치마크
  • PyTorch 대비 ONNX 결과 회귀 테스트
  • 세션 옵션과 스레드 정책을 코드로 고정
  • 양자화 산출물(ONNX 파일)을 빌드 아티팩트로 관리

위 흐름대로 적용하면 “그냥 ONNX로 바꿨더니 조금 빨라짐” 수준을 넘어, 비용과 지연 시간을 체계적으로 줄일 수 있습니다.