- Published on
Transformers 로컬 LLM KV 캐시로 TPS 2배 튜닝
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 LLM을 서빙하다 보면 GPU 사용률은 높게 찍히는데도 TPS가 기대만큼 안 나오는 경우가 많습니다. 특히 Transformers로 직접 추론 루프를 구성했을 때는 KV 캐시(Key/Value cache) 를 어떻게 쓰느냐에 따라 처리량이 크게 갈립니다.
이 글은 “왜 KV 캐시가 TPS를 좌우하는지”를 수식 수준까지 파고들기보다는, 실무에서 TPS를 2배에 가깝게 끌어올리는 데 필요한 설정/코드/측정 방법을 중심으로 정리합니다. (환경은 로컬 GPU 단일 노드 기준)
아래 내용은 “단일 요청 지연시간” 최적화가 아니라, 서빙 관점에서 TPS(초당 토큰 생성량)와 안정적인 처리량을 올리는 방향에 초점을 둡니다.
KV 캐시가 TPS를 올리는 원리(서빙 관점)
오토리그레시브(autoregressive) 디코딩에서 토큰을 1개 생성할 때마다 모델은 다음을 반복합니다.
- 입력 토큰 전체(프롬프트 + 지금까지 생성된 토큰)를 기반으로 attention 계산
- 다음 토큰 logits 산출
여기서 attention의 K/V(키/값) 텐서는 과거 토큰들에 대해 반복적으로 재사용됩니다. KV 캐시를 쓰지 않으면 매 스텝마다 과거 토큰들에 대한 K/V를 다시 계산하게 되어, 디코드 스텝이 길어질수록 비용이 급격히 증가합니다.
정리하면:
- 프리필(prefill): 프롬프트 길이만큼 한 번에 처리(상대적으로 큰 덩어리)
- 디코드(decode): 토큰 1개씩 반복 생성(매 스텝이 매우 빈번)
KV 캐시는 특히 디코드 구간의 반복 비용을 크게 줄여 TPS를 올리는 핵심 장치입니다.
“KV 캐시 켰는데 TPS가 안 오른다”의 흔한 원인
다음 중 하나라도 걸리면 캐시가 있어도 체감 TPS가 잘 안 오릅니다.
use_cache가 꺼져 있음(혹은 설정이 코드 경로에서 무시됨)generate가 매번 프롬프트부터 다시 태우는 구조(세션 캐시 미사용)- 배치/패딩 전략이 비효율적이라 캐시 이득이 상쇄됨
- 샘플링 설정이 GPU를 덜 쓰게 만들어 병목이 다른 곳에 생김
- 측정 방식이 잘못되어 프리필 비용이 TPS를 왜곡함
이 글은 1~5를 순서대로 정리합니다.
기본 전제: Transformers에서 KV 캐시가 어떻게 동작하나
Transformers 모델의 forward는 보통 past_key_values를 입력으로 받아 다음을 반환합니다.
- logits
past_key_values(업데이트된 캐시)
그리고 generate는 내부적으로 이 캐시를 활용해 디코드 루프를 돌립니다.
다만 서빙에서 TPS를 2배 수준으로 올리려면 단순히 generate 호출만으로는 부족한 경우가 많습니다. 이유는 다음과 같습니다.
generate는 범용 API라서, “요청 간 캐시 재사용(세션)”이나 “프리필/디코드 분리 측정” 같은 서빙 최적화가 제한적- 배치/패딩/동적 배치와 결합할 때 병목이 생기기 쉬움
그래서 보통은 generate를 쓰더라도, 최소한 캐시가 켜져 있는지 검증하고, 필요하면 직접 디코드 루프를 구성합니다.
1) 캐시가 진짜 켜져 있는지 확인하기
모델/설정/코드 경로에 따라 use_cache가 꺼져 있거나, 추론 시점에 무시되는 경우가 있습니다.
필수 체크리스트
model.config.use_cache가True인지- forward 호출에서
use_cache=True가 전달되는지 - 반환값에
past_key_values가 실제로 존재하는지
아래는 가장 단순한 검증 코드입니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
model.eval()
# 캐시 설정 확인
model.config.use_cache = True
prompt = "Explain KV cache in one paragraph."
inputs = tok(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
out = model(**inputs, use_cache=True)
print("has past_key_values:", out.past_key_values is not None)
if out.past_key_values is not None:
# 레이어 수 확인
print("layers:", len(out.past_key_values))
여기서 past_key_values가 None이면, KV 캐시가 동작하지 않는 경로일 확률이 높습니다.
2) 프리필/디코드를 분리해서 측정해야 TPS가 보인다
서빙 TPS를 논할 때 가장 흔한 실수는 “전체 요청 시간을 토큰 수로 나누는 방식”으로만 판단하는 것입니다. 프롬프트가 길수록 프리필 비용이 커져서, 디코드 TPS 개선이 묻힙니다.
권장 측정 방식:
- 프리필 시간(프롬프트를 모델에 한 번 태우는 시간)
- 디코드 시간(생성 토큰 구간만)
- 디코드 TPS = 생성 토큰 수 / 디코드 시간
아래는 generate 기반으로도 대략 분리 측정하는 예시입니다. (엄밀히는 내부 루프가 섞이지만, 실무 튜닝에는 충분히 유용합니다)
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Llama-2-7b-hf"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
model.eval()
model.config.use_cache = True
def benchmark(prompt: str, max_new_tokens: int = 128):
inputs = tok(prompt, return_tensors="pt").to("cuda")
# 프리필만 측정: forward 1회
torch.cuda.synchronize()
t0 = time.perf_counter()
with torch.no_grad():
out = model(**inputs, use_cache=True)
torch.cuda.synchronize()
t1 = time.perf_counter()
# 디코드 측정: generate
torch.cuda.synchronize()
t2 = time.perf_counter()
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
torch.cuda.synchronize()
t3 = time.perf_counter()
prefill_s = t1 - t0
total_s = t3 - t2
# 생성 토큰 수 계산(프롬프트 길이 제외)
prompt_len = inputs["input_ids"].shape[1]
gen_len = gen.shape[1]
new_tokens = gen_len - prompt_len
# 전체 generate 시간에는 프리필이 포함되므로, 대략적인 참고치로만 사용
approx_tps = new_tokens / total_s
return {
"prefill_s": prefill_s,
"generate_s": total_s,
"new_tokens": new_tokens,
"approx_tps": approx_tps,
}
print(benchmark("Write a short story about caching.", 128))
실제로 TPS를 2배로 끌어올렸는지 보려면, 최소한 “프롬프트 길이를 고정”하고 비교하거나, 더 나아가 직접 디코드 루프로 디코드만 분리 측정하는 것이 좋습니다.
3) 직접 디코드 루프로 캐시 효과를 확실히 얻기
서빙에서는 다음 패턴이 가장 단순하면서도 효과적입니다.
- 프리필 1회 수행해서
past_key_values확보 - 이후에는 매 스텝마다 “마지막 토큰 1개”만 넣고
past_key_values를 갱신
이 방식은 디코드 구간에서 입력 길이가 사실상 1로 고정되므로, 캐시 이득이 명확합니다.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Llama-2-7b-hf"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
model.eval()
model.config.use_cache = True
@torch.no_grad()
def decode_with_kv_cache(prompt: str, max_new_tokens: int = 128):
inputs = tok(prompt, return_tensors="pt").to("cuda")
input_ids = inputs["input_ids"]
# 1) prefill
torch.cuda.synchronize()
t0 = time.perf_counter()
out = model(input_ids=input_ids, use_cache=True)
past = out.past_key_values
torch.cuda.synchronize()
t1 = time.perf_counter()
# 2) decode loop
generated = []
next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
torch.cuda.synchronize()
t2 = time.perf_counter()
for _ in range(max_new_tokens):
out = model(input_ids=next_token, past_key_values=past, use_cache=True)
past = out.past_key_values
next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
generated.append(next_token)
torch.cuda.synchronize()
t3 = time.perf_counter()
prefill_s = t1 - t0
decode_s = t3 - t2
tps = max_new_tokens / decode_s
gen_ids = torch.cat([input_ids] + generated, dim=1)
text = tok.decode(gen_ids[0], skip_special_tokens=True)
return {
"prefill_s": prefill_s,
"decode_s": decode_s,
"decode_tps": tps,
"text": text,
}
res = decode_with_kv_cache("Explain transformers KV cache briefly.", 128)
print({k: v for k, v in res.items() if k != "text"})
이 코드로 “캐시 사용”과 “캐시 미사용”을 비교하면 차이가 더 선명하게 나타납니다. 캐시 미사용 비교는 루프에서 past_key_values를 전달하지 않고, 매번 전체 시퀀스를 넣는 방식으로 구성하면 됩니다. (다만 그 방식은 OOM 위험이 커서 길이를 짧게 두고 비교하세요)
4) TPS 2배를 만드는 실전 튜닝 포인트
여기부터가 핵심입니다. KV 캐시 자체는 기본 기능이지만, 서빙 TPS는 주변 조건(배치, 패딩, 메모리, 커널 선택)에 의해 크게 흔들립니다.
4-1) torch.inference_mode()와 eval() 고정
추론에서 autograd 오버헤드를 확실히 제거해야 합니다.
model.eval()with torch.inference_mode():권장 (no_grad보다 더 강함)
model.eval()
with torch.inference_mode():
out = model(**inputs, use_cache=True)
4-2) dtype과 attention 구현 선택
가능하면 다음 조합이 TPS에 유리한 경우가 많습니다.
torch.float16또는bfloat16(GPU 지원에 따라)- PyTorch SDPA 설정(가능한 경우 flash 계열 커널 유도)
환경에 따라 다르지만, PyTorch 2.x에서는 다음 설정이 도움이 될 수 있습니다.
import torch
torch.backends.cuda.matmul.allow_tf32 = True
# Ampere 이상에서 TF32가 matmul 처리량을 올릴 수 있음
# SDPA 커널 선택은 버전/모델에 따라 효과가 다름
# flash 우선이 항상 정답은 아니니 측정 기반으로 선택
주의: 모델/드라이버/버전에 따라 flash attention 계열이 특정 shape에서 오히려 느려질 수 있습니다. 반드시 측정하세요.
4-3) 배치 전략: “프리필 배치”와 “디코드 배치”는 성격이 다르다
서빙에서 TPS를 올리는 가장 강력한 레버는 배치지만, KV 캐시가 들어오면 배치가 까다로워집니다.
- 프리필: 입력 길이가 제각각이라 패딩 낭비가 큼
- 디코드: 스텝 동기화가 필요하고, 요청마다 종료 시점이 달라 “ragged batch” 문제가 생김
실전 팁:
- 프롬프트 길이가 비슷한 요청끼리 묶는 length bucketing
- 디코드에서는 “동일 step에서 함께 디코드할 요청”만 묶는 동적 배치(dynamic batching)
- 너무 큰 배치는 오히려 latency를 악화시키고, GPU 메모리 압박으로 TPS가 떨어질 수 있음
이 지점에서 서빙 프레임워크(KServe, vLLM, TGI 등)가 강점을 갖습니다. 다만 이 글은 Transformers 로컬 튜닝이 주제이므로, 직접 구성한다면 “프리필 큐”와 “디코드 큐”를 분리하는 설계를 추천합니다.
콜드스타트와 운영 배포 관점은 아래 글이 함께 도움이 됩니다.
4-4) 패딩 최적화: left padding과 attention mask
디코드에서 마지막 토큰만 넣는 루프를 구성하면 패딩 이슈가 줄지만, generate를 배치로 돌릴 때는 패딩이 TPS를 갉아먹습니다.
- Causal LM은 보통 left padding이 유리한 경우가 많습니다(오른쪽 패딩은 불필요한 계산이 늘 수 있음)
- 토크나이저에
padding_side="left"를 설정
tok.padding_side = "left"
if tok.pad_token is None:
tok.pad_token = tok.eos_token
모델에 따라 pad 토큰 처리 방식이 달라서, 이 설정은 반드시 품질(출력 이상)과 함께 검증해야 합니다.
4-5) 캐시 타입과 메모리: OOM이 TPS를 망친다
KV 캐시는 메모리를 먹습니다. 배치가 커질수록, 프롬프트가 길수록, 생성 토큰이 길수록 캐시는 선형으로 증가합니다.
OOM이 나지 않더라도 다음이 발생할 수 있습니다.
- GPU 메모리 압박으로 인해 allocator 단편화
- 더 잦은 CPU-GPU 동기화
- 결과적으로 TPS 하락
실전에서는 다음을 같이 봅니다.
max_new_tokens상한- 프롬프트 길이 상한(또는 요약/압축)
- 동시 요청 수 상한
프롬프트 폭주나 토큰 폭탄을 막는 운영 패턴은 아래 글도 참고할 만합니다.
5) “세션 KV 캐시”로 체감 TPS를 더 올리기(대화형)
지금까지는 “요청 1개 안에서” KV 캐시를 쓰는 방법이었습니다. 하지만 챗봇처럼 같은 사용자가 연속 턴을 보내면, 매 턴마다 과거 대화를 프롬프트로 다시 넣는 비용이 커집니다.
이때는 세션 단위로 past_key_values를 저장해 다음 턴에 재사용하는 방식이 효과적입니다.
개념적으로는:
- 첫 턴: prefill해서
past_key_values저장 - 다음 턴: 새로 들어온 사용자 입력만 토크나이즈해서 캐시에 이어 붙임
주의할 점:
- 토큰 경계 관리(시스템 프롬프트, 역할 토큰 등)
- 세션별 캐시 메모리 상한 및 eviction 정책(LRU 등)
- 사용자별로 캐시가 섞이지 않도록 격리
간단한 스케치 코드는 아래와 같습니다.
from dataclasses import dataclass
import torch
@dataclass
class SessionState:
past_key_values: object | None
last_token: torch.Tensor | None
sessions: dict[str, SessionState] = {}
@torch.no_grad()
def chat_step(session_id: str, user_text: str, model, tok, max_new_tokens: int = 64):
state = sessions.get(session_id) or SessionState(None, None)
# 새 입력 토큰화
# 실제 서비스에서는 역할 토큰/템플릿을 포함해야 함
inp = tok(user_text, return_tensors="pt").to("cuda")
input_ids = inp["input_ids"]
if state.past_key_values is None:
out = model(input_ids=input_ids, use_cache=True)
else:
out = model(input_ids=input_ids, past_key_values=state.past_key_values, use_cache=True)
past = out.past_key_values
next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
# 간단 디코드
gen = []
for _ in range(max_new_tokens):
out = model(input_ids=next_token, past_key_values=past, use_cache=True)
past = out.past_key_values
next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
gen.append(next_token)
sessions[session_id] = SessionState(past, next_token)
gen_ids = torch.cat(gen, dim=1)
return tok.decode(gen_ids[0], skip_special_tokens=True)
이 방식은 “대화 히스토리를 매번 프롬프트로 재전송”하는 비용을 줄여, 사용자 체감 응답성과 서버 TPS 모두에 도움이 됩니다.
단, 세션 캐시는 운영 리스크도 큽니다. 세션이 늘어나면 KV 캐시가 GPU 메모리를 잠식하므로, TTL/LRU로 과감하게 버려야 합니다.
6) TPS 2배를 목표로 한 튜닝 체크리스트
아래 체크리스트 순서대로 적용하면, 대체로 “눈에 보이는 TPS 개선”이 나옵니다.
model.eval()+torch.inference_mode()적용use_cache=True및past_key_values반환 여부 확인- 프리필/디코드 분리 측정(프롬프트 길이 고정)
- 직접 디코드 루프로 캐시 효과를 확인(마지막 토큰 1개 입력)
- dtype 최적화(
fp16또는bf16), TF32 허용 여부 점검 - 배치 전략 수립(프리필 버킷팅 + 디코드 동적 배치)
- 패딩 최적화(left padding, pad 토큰)
- OOM 및 메모리 압박 방지(상한 설정, 세션 캐시 eviction)
7) 마무리: KV 캐시는 “디코드 비용”을 줄이는 장치다
KV 캐시로 TPS를 2배 올리는 핵심은 단순히 옵션을 켜는 게 아니라,
- 디코드 구간에서 “과거 토큰 재계산”을 없애고
- 배치/패딩/메모리로 캐시 이득이 상쇄되지 않게 만들며
- 측정에서 프리필 비용을 분리해 개선 효과를 정확히 보는 것
입니다.
만약 여기까지 했는데도 TPS가 잘 안 오르면, 다음 단계는 “Transformers 순정 루프”를 넘어 전용 서빙 엔진(vLLM 계열의 paged attention, TGI 등)으로 넘어가는 것을 검토할 타이밍입니다. 하지만 로컬에서 빠르게 실험하고, 제품화 전 튜닝 포인트를 잡는 데는 이 글의 방식만으로도 충분히 큰 개선을 만들 수 있습니다.