- Published on
PyTorch 2 Export+TensorRT로 LLM FP8 압축
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서빙 비용을 낮추면서도 LLM 지연시간을 줄이는 가장 강력한 조합 중 하나가 PyTorch 2 export로 그래프를 안정적으로 고정한 뒤, TensorRT-LLM에서 FP8로 압축해 엔진을 빌드하는 흐름입니다.
다만 FP8은 단순히 dtype만 바꾸는 문제가 아니라, 캘리브레이션(스케일 추정), GEMM 커널 선택, KV cache dtype, 정확도 저하 모니터링까지 함께 설계해야 “빠른데 품질이 망가지지 않는” 결과가 나옵니다. 이 글에서는 다음을 목표로 합니다.
- PyTorch 2의
torch.export를 왜 쓰는지,torch.compile과의 역할 분리 - TensorRT-LLM에서 FP8 엔진을 만드는 실전 단계
- 정확도/속도/메모리 트레이드오프와 실패 패턴
- 최소 재현 가능한 코드 스니펫과 체크리스트
운영 관점에서 추적 가능성을 함께 챙기는 것을 권합니다. 성능 튜닝은 재현성이 핵심이라, 실험 파이프라인에 트레이싱을 붙여두면 회귀를 빨리 잡습니다. 관련해서는 OpenTelemetry로 MSA 분산 트랜잭션 추적 실전도 같이 참고하면 좋습니다.
전체 파이프라인 한 장 요약
FP8 압축을 “모델 파일을 FP8로 저장”하는 것으로 오해하는 경우가 많습니다. 실무에서는 아래처럼 나뉩니다.
- 모델 준비: Hugging Face 체크포인트, 또는 자체 학습 모델
- 그래프 고정: PyTorch 2
torch.export로 동적 shape 정책을 포함한 안정적 IR 생성 - 엔진 빌드: TensorRT-LLM 빌더로 FP8(또는 혼합) 정밀도, 플러그인, KV cache 설정
- 검증: 퍼플렉서티, 샘플 출력, 회귀 테스트, latency 및 VRAM 측정
- 서빙: Triton 또는 자체 gRPC/HTTP 서버 + 배치/스트리밍 정책
여기서 export는 “PyTorch 그래프를 깨끗하게 만들고, 빌드가 매번 같은 결과를 내도록 하는 장치”에 가깝고, 실제 FP8 압축은 TensorRT 쪽에서 일어납니다.
PyTorch 2에서 torch.export를 쓰는 이유
torch.compile은 런타임 최적화에 강하지만, 서빙용 엔진 빌드에서는 다음 문제가 생길 수 있습니다.
- 런타임에 shape 변동이 잦으면 그래프가 자주 깨지거나 재컴파일이 발생
- AOT로 내보낼 때 연산이 분해되거나(Decomposition) 예상치 못한 패턴이 생김
- 엔진 빌드 입력으로 “정확히 어떤 그래프가 나오는지” 통제하기 어려움
torch.export는 명시적으로 입력 스펙과 동적 차원을 선언해, “이 범위 안에서는 이 그래프”라는 계약을 만들 수 있습니다. TensorRT로 넘기기 전에 그래프를 고정해두면, 빌드 재현성과 디버깅 난이도가 크게 내려갑니다.
torch.export 최소 예시
아래 코드는 개념 예시입니다. 실제 LLM은 입력이 input_ids, attention_mask, position_ids, past_key_values 등으로 더 복잡하지만, 핵심은 Dim으로 동적 축을 선언하는 방식입니다.
import torch
from torch.export import export, Dim
class Toy(torch.nn.Module):
def forward(self, x):
return x @ x.transpose(-1, -2)
m = Toy().eval().cuda()
B = Dim("batch", min=1, max=8)
S = Dim("seq", min=1, max=4096)
example = (torch.randn(1, 16, 16, device="cuda"),)
# 실제로는 입력 텐서 shape에 맞게 dynamic_shapes를 작성
ep = export(
m,
args=example,
dynamic_shapes={"x": {0: B}},
)
print(ep.graph_module)
LLM에 적용할 때는 보통 batch, seq_len을 동적 축으로 선언하고, 엔진을 여러 개로 쪼개는 전략을 함께 씁니다.
- 프리필(prefill)용 엔진: 긴
seq_len처리에 최적화 - 디코드(decode)용 엔진:
seq_len이 작고 반복 호출이 많으므로 latency 최적화
FP8이 “효과가 큰” 이유와 전제 조건
FP8은 FP16 대비 메모리 사용량을 절반 수준으로 줄이면서도, 최근 GPU에서 GEMM 처리량이 크게 올라갑니다. 특히 LLM은 대부분이 행렬곱이므로 효과가 큽니다.
다만 전제 조건이 있습니다.
- 하드웨어: FP8 텐서 코어를 제대로 지원하는 GPU 세대 필요
- 소프트웨어: TensorRT 버전, TensorRT-LLM 버전, CUDA 드라이버 조합 호환
- 캘리브레이션: 스케일링 정책이 맞지 않으면 품질이 급격히 무너질 수 있음
실무적으로는 “전층 FP8”보다 FP8 + 일부 민감 레이어 FP16 유지 같은 혼합 전략이 안정적입니다.
TensorRT-LLM에서 FP8 엔진 빌드 흐름
TensorRT-LLM은 LLM에 특화된 최적화(Attention 플러그인, KV cache 최적화, fused kernel)를 제공합니다. FP8 압축은 보통 빌드 시점에 설정합니다.
아래는 개념적인 빌드 커맨드 예시입니다. 실제 옵션 이름은 버전에 따라 달라질 수 있으니, 사용 중인 TensorRT-LLM 릴리스의 CLI 도움말을 기준으로 맞추세요.
python -m tensorrt_llm.commands.build \
--model_dir ./hf_model \
--output_dir ./trt_engines \
--dtype fp16 \
--quantization fp8 \
--max_batch_size 8 \
--max_input_len 4096 \
--max_output_len 512 \
--use_paged_kv_cache enable
핵심은 다음 3가지입니다.
- 정밀도 정책:
--quantization fp8또는 유사 옵션 - shape 상한:
max_batch_size,max_input_len,max_output_len - KV cache 전략: paged KV cache는 긴 컨텍스트에서 메모리 파편화를 줄이는 데 유리
FP8 캘리브레이션(스케일) 준비
FP8은 보통 레이어별 또는 채널별 스케일이 필요합니다. 스케일을 얻는 방법은 크게 2가지입니다.
- 사전 정의된 규칙 기반 스케일링(간단하지만 품질 리스크)
- 대표 데이터셋으로 캘리브레이션(권장)
캘리브레이션용 데이터는 “실제 트래픽 분포”를 닮아야 합니다. 예를 들어 코드 생성 모델이면 코드 토큰이 충분히 포함돼야 하고, 한국어 질의가 많다면 한국어 프롬프트가 포함돼야 합니다.
간단한 캘리브레이션 샘플 생성 예시는 아래처럼 구성할 수 있습니다.
from datasets import load_dataset
from transformers import AutoTokenizer
model_id = "your-llm"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
ds = load_dataset("json", data_files="calib_prompts.jsonl")
def to_ids(ex):
t = tok(ex["prompt"], truncation=True, max_length=4096)
return {"input_ids": t["input_ids"]}
calib = ds["train"].map(to_ids, remove_columns=ds["train"].column_names)
# calib을 TensorRT-LLM 캘리브레이션 입력으로 사용
print(len(calib), calib[0]["input_ids"][:16])
정확도 검증: “좋아 보이는 샘플”만으로는 부족
FP8 적용 후 흔히 발생하는 함정은 다음입니다.
- 샘플 몇 개는 그럴듯한데, 특정 도메인에서만 붕괴
- 긴 컨텍스트에서만 답변 품질이 급락
- 온도, top-p 등 샘플링 하이퍼파라미터에 따라 열화가 과장 또는 은폐
권장 검증 루틴은 아래처럼 3단계입니다.
- 정량 지표: 퍼플렉서티 또는 태스크별 자동 평가
- 회귀 테스트: 고정 프롬프트 세트로 출력 diff를 기록
- 운영 지표: latency P50/P95, 토큰당 지연, OOM/리트라이율
회귀 테스트는 “엔진 빌드 옵션이 바뀌었는데도 결과가 같은지”를 빠르게 확인합니다.
import json
import hashlib
def stable_hash(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
# outputs: {"prompt_id": ..., "text": ...}
with open("outputs_fp16.jsonl") as f1, open("outputs_fp8.jsonl") as f2:
a = [json.loads(x) for x in f1]
b = [json.loads(x) for x in f2]
pairs = {x["prompt_id"]: x["text"] for x in a}
for x in b:
pid = x["prompt_id"]
if pid in pairs:
if stable_hash(pairs[pid]) != stable_hash(x["text"]):
print("diff", pid)
완전 일치가 목표는 아니지만, 특정 프롬프트에서만 큰 흔들림이 있으면 “민감 레이어 FP16 유지” 같은 대응을 고려해야 합니다.
성능 튜닝 포인트: FP8만으로 끝나지 않는다
FP8을 적용해도, 실제 TPS는 다음 요소에 의해 크게 좌우됩니다.
1) 프리필과 디코드 분리
- 프리필: 긴 시퀀스 한 번 처리, 대역폭과 fused attention 영향 큼
- 디코드: 작은 시퀀스 반복, 커널 런치 오버헤드와 KV cache 접근 패턴이 중요
엔진을 분리하면 각각의 최적화를 극단적으로 걸 수 있습니다.
2) KV cache dtype과 메모리 정책
KV cache를 FP8로 내리면 메모리는 줄지만 품질이 흔들릴 수 있습니다. 보수적으로는
- 가중치: FP8
- 활성/계산: FP16 또는 BF16
- KV cache: FP16 유지
같은 구성이 안정적인 출발점입니다.
3) 배치/스트리밍 정책
서빙에서는 동시성 때문에 배치가 커지지만, 디코드는 토큰 단위 스트리밍이 많아 “마이크로 배치”가 됩니다. 동적 배치 스케줄링이 필요하고, 이때 빌드된 엔진의 max_batch_size와 실제 트래픽 분포가 맞지 않으면 성능이 오히려 떨어집니다.
빌드 및 배포 파이프라인을 자주 돌리게 된다면, 컨테이너 빌드 최적화도 같이 챙기면 좋습니다. 예를 들어 Docker 빌드 느림? BuildKit 캐시·레이어 최적화 12처럼 레이어 캐시를 정리하면 실험 속도가 체감됩니다.
실패 패턴과 디버깅 체크리스트
패턴 A: 엔진 빌드는 되는데 출력이 이상함
- 캘리브레이션 데이터가 너무 짧거나 편향됨
- 특정 레이어가 FP8에 민감(입력 분포가 극단적)
- 로터리 포지셔널 임베딩, RMSNorm 등 주변 연산이 예상과 다르게 변환됨
대응:
- 캘리브레이션 프롬프트를 실제 트래픽에서 샘플링
- 민감 레이어를 FP16으로 강제(가능한 옵션이 있다면)
- 프리필과 디코드 엔진을 분리해 문제 구간을 좁힘
패턴 B: 특정 입력 길이에서만 OOM
max_input_len을 크게 잡아 workspace가 폭증- paged KV cache 비활성화로 메모리 파편화
대응:
- 엔진을 길이 구간별로 여러 개 생성
- paged KV cache 또는 메모리 풀 옵션 활성화
패턴 C: CI에서 재현이 안 되고 로컬만 빠름
- 드라이버, CUDA, TensorRT 버전 불일치
- 빌드 캐시 미사용으로 매번 엔진 재생성
대응:
- 빌드 환경을 컨테이너로 고정
- CI 캐시 전략 점검(특히 대용량 아티팩트)
Node 기반 툴링이 섞여 있다면 캐시가 의도대로 동작하지 않는 경우가 많습니다. GitHub Actions에서 node_modules 캐시가 안 먹힐 때 같은 패턴을 참고해, 엔진 아티팩트 캐시도 같은 방식으로 점검하세요.
실전 예시: export 스펙과 엔진 빌드 입력 정리
LLM의 입력은 보통 아래처럼 구성됩니다.
input_ids:int32attention_mask:int32또는boolposition_ids:int32past_key_values: 레이어 수만큼의 KV 텐서
torch.export에서 중요한 것은 “동적 축을 어디까지 허용할지”입니다. 예를 들어
batch:1..8seq_len: 프리필은1..4096, 디코드는1..1
처럼 엔진을 분리하면, 디코드 엔진은 매우 공격적으로 최적화할 수 있습니다.
아래는 입력 스펙을 구조화해 두는 예시입니다.
from dataclasses import dataclass
@dataclass
class EngineSpec:
name: str
max_batch: int
max_input_len: int
max_output_len: int
quant: str # "fp8" or "none"
prefill = EngineSpec(
name="prefill",
max_batch=8,
max_input_len=4096,
max_output_len=1,
quant="fp8",
)
decode = EngineSpec(
name="decode",
max_batch=32,
max_input_len=1,
max_output_len=512,
quant="fp8",
)
print(prefill, decode)
이런 식으로 “엔진 스펙을 코드로 관리”하면, 실험이 늘어날수록 옵션 난립을 막을 수 있습니다.
운영 관점 권장사항
- FP8은 성능 이득이 큰 대신, 버전 호환성과 캘리브레이션 품질에 민감합니다. 엔진 빌드 로그, 커밋 해시, 드라이버 버전, GPU 모델을 함께 기록하세요.
- 성능 회귀를 잡기 위해 토큰당 지연(
ms/token), 프리필 시간, 디코드 루프 시간을 분리해서 메트릭으로 내보내는 것을 권장합니다. - 실험 파이프라인이 길어지면 “빌드 최적화”가 곧 연구 속도입니다. 컨테이너 레이어와 캐시를 정리해두면 FP8 실험 반복이 훨씬 빨라집니다.
마무리
PyTorch 2 export는 LLM 그래프를 안정적으로 고정해 “엔진 빌드의 재현성”을 확보하는 도구이고, TensorRT-LLM FP8은 그 위에서 “실제 서빙 성능과 VRAM”을 크게 개선하는 레버입니다.
성공 확률을 높이는 핵심은
- 프리필과 디코드를 분리해 엔진을 설계하고
- 캘리브레이션 데이터를 트래픽에 가깝게 만들고
- FP8 적용 범위를 혼합 전략으로 조절하며
- 정량 평가와 회귀 테스트로 품질을 감시
하는 것입니다.
원하면 사용 중인 모델 계열(예: Llama 계열, Qwen 계열), 목표 GPU, 목표 컨텍스트 길이, 서빙 형태(스트리밍 여부)를 기준으로 엔진 스펙 템플릿과 캘리브레이션 프롬프트 구성 가이드까지 더 구체화해 드릴 수 있습니다.