- Published on
로컬 LLM 속도 2배 - KV 캐시·FlashAttention2
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬에서 LLM을 돌리다 보면 “GPU는 80%도 안 쓰는데 토큰이 안 나온다”, “프롬프트가 길어지면 갑자기 느려진다” 같은 현상을 자주 겪습니다. 대부분의 경우 원인은 모델이 아니라 어텐션 계산과 메모리 이동에 있습니다. 특히 Transformers 기반 오토리그레시브 디코딩은 토큰을 한 개씩 생성하면서 같은 계산을 반복하기 때문에, 두 가지 레버가 성능을 크게 좌우합니다.
- KV 캐시(Key/Value cache): 이전 토큰들의 K,V를 저장해 다음 토큰에서 재사용
- FlashAttention2: 어텐션을 메모리-효율적으로 커널 수준에서 재구성해 대역폭 병목 완화
이 글은 “이론 설명”에서 끝내지 않고, 로컬 추론에서 실제로 체감되는 2배급 개선을 만들기 위한 설정, 코드, 검증 포인트를 중심으로 정리합니다.
왜 느려지는가: 프리필과 디코드의 병목 분리
LLM 추론은 크게 두 단계로 나뉩니다.
- Prefill(프롬프트 처리): 입력 시퀀스 전체를 한 번에 통과시키며 첫 토큰을 준비
- Decode(생성 루프): 이후 토큰을 1개씩 생성하며 반복
긴 프롬프트에서 느려지는 이유는 prefill 비용이 커지기 때문이고, 생성 속도가 느린 이유는 decode 단계에서 토큰마다 어텐션을 반복 계산하기 때문입니다.
어텐션의 계산 비용을 단순화해서 보면, 시퀀스 길이를 L, 헤드 차원을 D라 할 때:
- KV 캐시가 없으면 매 토큰마다 과거 전체에 대한 K,V를 다시 만들고, 어텐션도 다시 계산합니다.
- KV 캐시가 있으면 “이번 토큰의 Q”만 새로 만들고, 과거 K,V는 재사용합니다.
즉 decode 단계에서 KV 캐시는 거의 필수 최적화이고, FlashAttention2는 그 어텐션 연산 자체를 더 빠르게 수행하게 해줍니다.
KV 캐시: 토큰 생성 속도를 올리는 1순위
KV 캐시가 하는 일
오토리그레시브 디코딩에서 매 스텝마다 필요한 것은:
- 새 토큰의 Q,K,V
- 과거 토큰들의 K,V
여기서 과거 K,V를 매번 다시 계산하면 낭비가 큽니다. KV 캐시는 레이어별로 누적된 K,V를 저장해 두고, 다음 스텝에서 이어 붙여 사용합니다.
Transformers에서 KV 캐시 켜기
Hugging Face Transformers에서는 보통 use_cache=True가 핵심입니다. 다만 모델/버전/컴파일 옵션에 따라 기본값이 다를 수 있어 명시하는 편이 안전합니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
prompt = "Explain KV cache in one paragraph."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=200,
do_sample=False,
use_cache=True,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
KV 캐시가 꺼지는 흔한 함정
아래 중 하나라도 걸리면 “캐시를 켰는데도 느린” 상황이 나옵니다.
- 학습 모드:
model.train()상태거나torch.no_grad()미사용 - 잘못된 디코딩 루프:
generate대신 직접 루프를 돌면서past_key_values를 전달하지 않음 - 특정 최적화와 충돌: 일부 환경에서
gradient_checkpointing옵션이 켜져 있으면 캐시가 비활성화되기도 함(추론에서는 끄는 게 정석) - 배치/패딩 처리 문제: 배치 내 길이 차이가 크면 캐시 이득이 줄고, 패딩이 많으면 메모리 낭비가 커짐
직접 디코딩 루프에서 past_key_values 사용 예
커스텀 로직(스트리밍, 토큰 단위 제어)을 위해 직접 루프를 돌릴 때는 반드시 past_key_values를 이어줘야 합니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "gpt2" # 데모용
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda").eval()
prompt = "The key idea is"
input_ids = tok(prompt, return_tensors="pt").input_ids.to("cuda")
past = None
for _ in range(50):
with torch.inference_mode():
outputs = model(
input_ids=input_ids,
past_key_values=past,
use_cache=True,
)
logits = outputs.logits[:, -1, :]
next_id = torch.argmax(logits, dim=-1, keepdim=True)
past = outputs.past_key_values
input_ids = next_id
print("done")
이 패턴이 아니면 KV 캐시를 사실상 “안 쓰는” 디코딩이 됩니다.
FlashAttention2: 어텐션을 메모리 병목에서 구출
FlashAttention2가 빠른 이유
일반적인 어텐션 구현은 중간 행렬(특히 L x L 스코어)을 만들면서 메모리 트래픽이 폭증합니다. FlashAttention 계열은 이를 타일링하고, 필요한 값만 스트리밍하며, softmax를 안정적으로 계산해 HBM 대역폭 병목을 줄이고 커널 효율을 끌어올립니다.
특히 긴 컨텍스트에서 효과가 커지고, decode 단계에서도 헤드 수/차원/정밀도 조건이 맞으면 이득이 납니다.
Transformers에서 FlashAttention2 활성화
최근 Transformers는 attn_implementation 옵션으로 플러그인처럼 선택할 수 있습니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "mistralai/Mistral-7B-Instruct-v0.2" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="cuda",
attn_implementation="flash_attention_2",
)
inputs = tokenizer("Write a short note about attention kernels.", return_tensors="pt").to("cuda")
with torch.inference_mode():
y = model.generate(**inputs, max_new_tokens=128, use_cache=True)
print(tokenizer.decode(y[0], skip_special_tokens=True))
설치 체크
환경에 따라 flash-attn 패키지가 필요합니다.
pip install -U flash-attn --no-build-isolation
주의할 점은 CUDA, PyTorch, GPU 아키텍처, 컴파일 옵션에 민감하다는 것입니다. 설치가 실패하거나 런타임에서 폴백되면 성능 향상이 사라집니다.
FlashAttention2가 폴백되는 전형적인 경우
- dtype이 맞지 않음: 모델이
float32로 떠 있거나, 혼합 정밀이 깨짐 - GPU 아키텍처/드라이버 호환 문제
- 특정 모델 구조에서 지원되지 않는 어텐션 변형(예: 일부 마스킹/로프 구현)
이때는 속도가 “그대로”이거나 오히려 느려질 수 있으니, 반드시 실제로 FlashAttention 커널이 사용되는지 확인해야 합니다.
2배를 만드는 조합: KV 캐시 + FlashAttention2 + 추론 모드
KV 캐시와 FlashAttention2는 서로 대체 관계가 아니라 누적 효과가 납니다.
- KV 캐시: decode 단계에서 중복 계산 제거
- FlashAttention2: 남은 어텐션 연산을 더 빠른 커널로 처리
torch.inference_mode(): autograd 오버헤드 제거
추가로, 아래 설정도 체감에 영향을 많이 줍니다.
dtype 선택: float16 vs bfloat16
- Ampere 이후 GPU에서는
bfloat16이 안정성과 성능이 좋은 경우가 많습니다. - 모델 가중치가 어떤 dtype로 제공되는지, GPU가 어떤 연산을 잘하는지에 따라 달라집니다.
torch.compile은 상황을 보고
PyTorch torch.compile은 모델과 환경에 따라 이득이 크기도 하지만, 초기 컴파일 비용이 있고 일부 조합에서는 역효과가 날 수 있습니다. “서빙 프로세스가 오래 떠 있는” 로컬 서버라면 고려할 가치가 있습니다.
import torch
model = torch.compile(model, mode="reduce-overhead")
컴파일 후에는 첫 요청이 느려질 수 있으니 워밍업을 넣는 게 좋습니다.
성능 측정: 토큰/초를 제대로 재는 법
최적화는 “느낌”이 아니라 수치로 해야 합니다. 최소한 아래 두 지표는 분리해서 보세요.
- prefill latency: 첫 토큰이 나오기까지 시간
- decode throughput: 생성 토큰/초
간단한 측정 코드 예시는 다음과 같습니다.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "gpt2"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda").eval()
prompt = "Explain why KV cache matters for autoregressive decoding. " * 50
inputs = tok(prompt, return_tensors="pt").to("cuda")
# 워밍업
with torch.inference_mode():
_ = model.generate(**inputs, max_new_tokens=8, use_cache=True)
torch.cuda.synchronize()
t0 = time.time()
with torch.inference_mode():
out = model.generate(**inputs, max_new_tokens=128, use_cache=True, do_sample=False)
torch.cuda.synchronize()
t1 = time.time()
new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
print("seconds:", t1 - t0)
print("new_tokens:", new_tokens)
print("tok/s:", new_tokens / (t1 - t0))
여기서 torch.cuda.synchronize()는 GPU 비동기 실행 때문에 측정이 흔들리는 것을 줄여줍니다.
로컬 서빙 관점: 캐시가 “메모리 폭주”로 이어질 때
KV 캐시는 속도를 주지만, 대가로 GPU 메모리를 먹습니다. 특히 배치가 커지거나 동시 요청이 늘면 KV 캐시가 누적되어 OOM이 납니다.
이 문제는 로컬 서버에서도 흔하고, 에이전트/툴 호출이 붙으면 더 심해집니다. 메모리 관리 관점에서는 아래 글이 같이 도움이 됩니다.
실전 대응은 다음이 핵심입니다.
- 세션별 최대 컨텍스트 길이 제한
- 동시성 제한(큐잉)
- 필요 시 프롬프트 요약/압축
- 긴 대화는 RAG로 외부화
운영 체크리스트: “켰는데 왜 안 빨라요”를 끝내기
아래 항목을 순서대로 점검하면 원인을 빨리 좁힐 수 있습니다.
- 모델이 eval 모드인지:
model.eval() - 추론 컨텍스트인지:
torch.inference_mode() - KV 캐시가 실제로 쓰이는지: 직접 루프면
past_key_values전달 확인 - FlashAttention2가 실제로 활성화됐는지:
attn_implementation지정 및 런타임 폴백 여부 확인 - dtype이 맞는지:
float16또는bfloat16로 일관되게 유지 - 측정 방법이 올바른지: 워밍업, 동기화, prefill과 decode 분리
추가로, 외부 API를 섞어 쓰는 하이브리드 구성이라면 “재시도 폭주로 전체 지연이 늘어나는” 문제가 성능 체감에 큰 영향을 줍니다. 로컬 LLM이 빨라져도 외부 호출이 병목이면 전체 체감은 그대로일 수 있습니다.
결론: 2배는 ‘특별한 트릭’이 아니라 기본기
Transformers 로컬 추론을 빠르게 만드는 가장 확실한 루트는 다음 조합입니다.
- decode 단계 중복 계산을 제거하는 KV 캐시
- 어텐션 커널 효율을 끌어올리는 FlashAttention2
- 측정과 운영에서 실수를 줄이는 추론 모드, dtype, 워밍업, 동시성 제어
이 셋을 제대로 적용하면, 동일 GPU에서도 토큰/초가 눈에 띄게 올라가고(특히 긴 컨텍스트에서), “프롬프트가 길어지면 급격히 느려지는” 문제도 완화됩니다. 다음 단계로는 배치 전략, speculative decoding, 모델 양자화까지 확장할 수 있지만, 그 전에 KV 캐시와 FlashAttention2부터 확실히 재현 가능하게 잡아두는 것이 가장 높은 ROI입니다.