- Published on
Transformers 로컬 LLM OOM? KV 캐시 절감 5가지
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU(예: 8GB~24GB)에서 transformers로 LLM 추론을 하다 보면, 모델 가중치를 다 올렸는데도 생성 길이가 조금만 길어져도 CUDA out of memory가 터지는 경우가 많습니다. 이때 진짜 범인은 KV 캐시(Key/Value cache) 인 경우가 흔합니다. 특히 배치가 커지거나 컨텍스트가 길어질수록 KV 캐시는 선형으로 증가해 메모리를 빠르게 잡아먹습니다.
이 글에서는 KV 캐시가 왜 OOM을 유발하는지 감을 잡을 수 있도록 대략적인 메모리 스케일을 설명하고, 당장 적용 가능한 KV 캐시 절감 5가지를 transformers 코드와 함께 정리합니다.
참고로 “OOM을 진단하고 원인을 쪼개는 방식”은 JVM 영역이지만 접근법은 비슷합니다. 메모리 병목을 단계적으로 확인하는 흐름은 Spring Boot OutOfMemoryError 덤프 분석·튜닝 7단계도 같이 보면 도움이 됩니다.
KV 캐시가 커지는 이유 (OOM의 핵심)
오토리그레시브 디코딩에서 다음 토큰을 생성할 때, 매 스텝마다 과거 토큰들의 어텐션을 다시 계산하면 너무 느립니다. 그래서 transformers는 기본적으로 use_cache=True로 각 레이어의 K/V 텐서를 저장해두고 재사용합니다.
KV 캐시 메모리는 대략 다음 요소에 비례합니다.
- 배치 크기
B - 레이어 수
L - 헤드 수
H - 헤드 차원
D - 누적 시퀀스 길이
S(프롬프트 길이 + 생성된 길이) - 데이터 타입 바이트 수 (예:
fp16는 2바이트)
직관적으로는 S가 늘어날수록 매 레이어마다 K와 V가 계속 쌓이는 구조라서, “프롬프트를 길게 넣고 길게 생성”하는 패턴에서 OOM이 잘 납니다.
OOM인지 KV 캐시인지 빠르게 구분하기
아래 증상이 있으면 KV 캐시 가능성이 큽니다.
- 짧은 생성은 되는데
max_new_tokens를 늘리면 특정 길이에서 갑자기 OOM - 배치
B를 1에서 2로만 올려도 급격히 메모리 부족 - 동일 모델인데 프롬프트 길이만 늘리면 OOM
간단히 GPU 메모리를 찍어보면 증가 패턴이 보입니다.
import torch
def report(tag=""):
torch.cuda.synchronize()
alloc = torch.cuda.memory_allocated() / 1024**2
reserved = torch.cuda.memory_reserved() / 1024**2
peak = torch.cuda.max_memory_allocated() / 1024**2
print(f"[{tag}] alloc={alloc:.1f}MiB reserved={reserved:.1f}MiB peak={peak:.1f}MiB")
# 사용 예
# report("before")
# ... generate ...
# report("after")
generate가 진행될수록 alloc이 꾸준히 증가하면 KV 캐시가 주 원인일 확률이 높습니다.
1) 생성 길이와 컨텍스트 길이를 강하게 제한하기
가장 확실하고, 대부분의 경우 가장 효과적입니다. KV 캐시는 S에 선형 비례하므로 아래 2가지를 먼저 조절하세요.
- 입력 컨텍스트 길이 제한:
max_length또는 토크나이저 트렁케이션 - 출력 생성 길이 제한:
max_new_tokens
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_id = "meta-llama/Llama-2-7b-chat-hf" # 예시
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
prompt = """아래 로그를 요약해줘: ..."""
inputs = tok(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048, # 컨텍스트 상한
).to("cuda")
out = model.generate(
**inputs,
max_new_tokens=256, # 생성 상한
do_sample=False,
use_cache=True,
)
print(tok.decode(out[0], skip_special_tokens=True))
실전 팁
- “프롬프트는 길게, 답변은 짧게”가 아니라면, 보통은
max_new_tokens가 OOM 트리거입니다. - RAG라면 top-k 문서 수를 줄이거나, 문서 chunk 크기를 줄여 입력 길이부터 통제하세요.
2) 배치 크기와 동시 요청 수를 줄이기 (마이크로 배칭 주의)
KV 캐시는 B에 선형 비례합니다. 로컬 서버를 띄워 동시 요청을 처리할 때, “배치로 묶어 처리하면 효율적”이라고 생각하고 마이크로 배칭을 넣었다가 OOM이 나는 경우가 많습니다.
- 동시 요청 수 제한(큐잉)
- 배치 크기
B제한 - 스트리밍 응답을 켜서 긴 생성이 겹치지 않게 운영
예를 들어 FastAPI나 간단한 서버에서 세마포어로 동시성을 제한하는 방식이 현실적입니다.
import asyncio
sem = asyncio.Semaphore(1) # GPU 1장이라면 1~2부터 시작
async def generate_with_limit(fn, *args, **kwargs):
async with sem:
return await fn(*args, **kwargs)
서버가 계속 재시작될 정도로 메모리 압박이 심하면, “원인 추적 흐름” 자체는 인프라/서비스 공통입니다. 운영 관점의 재시작 원인 추적은 systemd 서비스가 계속 재시작될 때 원인 추적법도 함께 참고할 만합니다.
3) KV 캐시 dtype을 낮추기 (가능하면 fp8/int8 계열)
가중치 양자화(4bit, 8bit)만으로는 KV 캐시가 그대로 fp16로 남아 OOM이 계속 나는 경우가 있습니다. 이때는 KV 캐시 자체의 dtype을 낮추는 옵션/구성을 우선 검토하세요.
3-1) torch_dtype=bfloat16 또는 float16로 확실히 고정
환경에 따라 기본 dtype이나 일부 연산이 fp32로 승격되면 캐시도 커질 수 있습니다.
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
)
3-2) 최신 스택에서 KV 캐시 양자화 옵션 활용
최근에는 transformers에서 cache_implementation 및 kv_cache 관련 최적화가 계속 들어오고 있습니다. 버전에 따라 사용 가능 옵션이 다르므로, 아래처럼 “지원되는지 확인하고 켜는 방식”이 안전합니다.
gen_kwargs = dict(
max_new_tokens=256,
do_sample=False,
use_cache=True,
)
# 일부 버전에서만 동작할 수 있음
# 예: cache_implementation="static" 또는 "offloaded" 등
try:
gen_kwargs["cache_implementation"] = "static"
except Exception:
pass
out = model.generate(**inputs, **gen_kwargs)
현실적인 결론은 이렇습니다.
- 가중치 양자화만으로 부족하면 KV 캐시 쪽 최적화가 필수
- 다만 옵션 이름/동작은 버전 의존성이 크므로,
transformers릴리즈 노트와generate시그니처를 꼭 확인하세요
4) use_cache=False로 캐시를 끄고, 대신 짧게 생성하거나 다른 디코딩 전략을 쓰기
use_cache=False는 KV 캐시 메모리를 크게 줄이지만, 매 토큰마다 과거를 다시 계산해서 속도가 급격히 느려질 수 있습니다. 그래도 “어떤 경우에도 OOM은 안 나야 한다” 같은 디버깅/비상 모드로 유용합니다.
out = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
use_cache=False, # KV 캐시 비활성화
)
언제 유효한가
- 생성 길이가 짧은 작업(분류/짧은 요약/태그 생성)
- 디버깅: OOM이 KV 캐시 때문인지 확인
- GPU 메모리가 매우 빡빡한 엣지 디바이스
정확도 보완 팁
캐시를 끄고 생성 길이를 줄이면 답이 부정확해질 수 있습니다. 이때 “한 번의 긴 생성” 대신 “여러 번의 짧은 생성 결과를 합의”하는 방식이 도움이 될 수 있습니다. 예를 들어 Self-Consistency 아이디어는 Chain-of-Thought 막힘? Self-Consistency로 정확도↑에서 개념을 잡을 수 있습니다.
5) 프롬프트를 줄이지 못한다면: 슬라이딩 윈도우/청킹/요약으로 KV 캐시 성장 자체를 막기
“입력 컨텍스트가 길어질 수밖에 없는” 워크로드가 있습니다.
- 장문 대화 메모리
- 로그/코드 리뷰
- 긴 문서 QA
이때는 단순히 max_length로 자르는 순간 품질이 크게 떨어질 수 있으니, 컨텍스트를 구조적으로 줄이는 전략이 필요합니다.
5-1) 대화 메모리 요약(rolling summary)
대화 히스토리를 그대로 누적하지 말고, 일정 길이를 넘으면 요약본으로 압축합니다.
def build_prompt(system, summary, recent_turns):
# summary는 누적 요약(짧게 유지)
# recent_turns는 최근 N턴만 유지
parts = [f"SYSTEM: {system}"]
if summary:
parts.append(f"SUMMARY: {summary}")
parts.extend(recent_turns)
return "\n".join(parts)
5-2) 문서 청킹 + top-k 축소
RAG에서 KV 캐시는 “질문 + 검색 문서” 길이에 좌우됩니다. 보통 성능 대비 메모리 관점에서 다음이 잘 먹힙니다.
- chunk 크기 축소(예: 800 tokens
->300~500 tokens, 여기서->는 반드시 인라인 코드로 표기) - top-k 축소(예: 8
->3~5) - 중복 문서 제거, 문서 헤더/푸터 제거
마무리 체크리스트 (OOM 재발 방지)
max_new_tokens를 먼저 줄였는가- 입력
max_length를 강제하고 있는가(토크나이저truncation=True포함) - 배치/동시성을 제한했는가(특히 서버)
- 가중치 양자화만 믿고 KV 캐시를 방치하고 있진 않은가
- 긴 컨텍스트가 필요하면 요약/청킹으로 구조적으로 줄였는가
KV 캐시는 “한 번 켜두면 알아서 빨라지는 공짜 최적화”처럼 보이지만, 로컬 환경에서는 메모리 상한을 순식간에 넘기는 가장 흔한 원인입니다. 위 5가지를 순서대로 적용하면, 같은 GPU에서도 생성 길이 안정성이 눈에 띄게 좋아질 것입니다.