- Published on
Transformers 로컬 LLM OOM - 4-bit·KV 캐시 튜닝
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU에서 transformers로 LLM을 띄우다 보면, 모델은 로드되는데 첫 토큰 생성 시점에 갑자기 OOM이 나거나, 배치만 조금 키웠을 뿐인데 VRAM이 폭증하는 일을 자주 겪습니다. 이 문제는 단순히 “모델 파라미터가 커서”만이 아니라, KV 캐시(Key/Value cache), 시퀀스 길이, 배치/동시성, attention 구현 방식이 합쳐져 터지는 경우가 많습니다.
이 글에서는 OOM을 파라미터 메모리와 **런타임 메모리(KV 캐시 포함)**로 나눠서 진단하고, 가장 효과가 큰 처방인 4-bit 양자화와 KV 캐시 튜닝을 중심으로, 로컬 환경에서 바로 적용 가능한 설정과 코드 예제를 제공합니다.
OOM의 2가지 축: 파라미터 vs 런타임
1) 파라미터(가중치) 메모리
- 모델 파일 크기와 거의 비례합니다.
- FP16/BF16 로딩이면 VRAM을 크게 먹습니다.
- 4-bit/8-bit 양자화로 가장 드라마틱하게 줄일 수 있습니다.
2) 런타임 메모리
- 대표적으로 KV 캐시가 핵심입니다.
max_new_tokens가 크거나,prompt가 길거나, 배치가 커지면 선형적으로 늘어납니다.- flash-attn/SDPA 같은 attention 커널 선택에 따라서도 피크 메모리가 달라집니다.
OOM이 “모델 로드 후 생성 시점”에 터지면, 대부분 KV 캐시 또는 attention 중간 텐서 피크가 원인입니다.
KV 캐시가 왜 VRAM을 폭발시키나
KV 캐시는 디코딩(autoregressive generation)에서 이전 토큰의 attention을 재사용하기 위해, 각 레이어마다 key/value 텐서를 저장해두는 구조입니다. 대략적인 증가 요인은 다음과 같습니다.
- 레이어 수가 많을수록 증가
- hidden size가 클수록 증가
- 헤드 수가 많을수록 증가
- 컨텍스트 길이(
prompt_len + generated_len)가 길수록 증가 - 배치가 커질수록 증가
즉, 4-bit로 가중치를 줄여도 KV 캐시는 기본적으로 FP16/BF16으로 남는 경우가 많아, “모델은 4-bit로 로드되는데도 OOM”이 충분히 발생합니다.
4-bit 양자화: 가장 먼저 적용할 처방
bitsandbytes 기반 4-bit 로딩은 로컬 OOM 해결에서 가장 ROI가 큽니다. 아래는 transformers에서 가장 흔히 쓰는 NF4 설정 예시입니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "meta-llama/Llama-2-7b-chat-hf" # 예시
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16, # GPU가 BF16 지원이면 권장
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model.eval()
4-bit에서 자주 겪는 함정
1) torch_dtype와 bnb_4bit_compute_dtype 불일치
bnb_4bit_compute_dtype는 matmul 등 연산 dtype에 영향을 줍니다.torch_dtype는 일부 모듈 로딩 dtype 힌트로도 쓰입니다.- 권장: BF16 가능 GPU면 둘 다 BF16로 맞추는 편이 안전합니다.
2) device_map="auto"가 오히려 느리거나 불안정
- VRAM이 부족하면 CPU offload가 섞이면서 속도와 안정성이 흔들릴 수 있습니다.
- 로컬 단일 GPU라면
device_map={"": 0}로 고정하는 것도 방법입니다.
KV 캐시 튜닝: OOM의 진짜 트리거 잡기
4-bit로도 OOM이 난다면, 다음 우선순위로 KV 캐시를 줄여야 합니다.
1) 생성 길이 제한: max_new_tokens와 입력 길이 관리
KV 캐시는 “지금까지의 총 토큰 수”에 비례합니다. 즉,
- 프롬프트가 길면 시작부터 이미 KV 캐시가 큽니다.
max_new_tokens가 크면 생성하면서 계속 커집니다.
inputs = tokenizer(
"긴 프롬프트...",
return_tensors="pt",
truncation=True,
max_length=2048, # 모델 컨텍스트 한도보다 작게 운영
).to(model.device)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
use_cache=True,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
실무 팁:
- “OOM이 가끔 난다”는 케이스는 대부분 프롬프트 길이 분산이 크거나, 특정 요청이 유난히 길어서 발생합니다.
- 서버형이라면 요청별 토큰 상한을 강제하고, 초과 입력은 요약/리트리벌로 우회하는 설계가 필요합니다.
2) 배치와 동시성: KV 캐시는 배치에 선형으로 증가
로컬에서 여러 요청을 한 번에 처리하거나, 스트리밍을 위해 내부적으로 배치를 쌓으면 KV 캐시가 그대로 곱해집니다.
- 단일 GPU 로컬 테스트는 배치
1부터 시작 - 동시성은 큐로 직렬화하거나, 작은 마이크로배치로 제한
3) use_cache=False는 최후의 수단
KV 캐시를 꺼버리면 메모리는 줄지만, 디코딩이 매 토큰마다 전체 시퀀스를 다시 attention 하므로 속도가 크게 떨어집니다.
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=128,
use_cache=False,
)
정리하면:
- OOM 회피 목적이면 효과는 큼
- 하지만 체감 성능이 크게 떨어져 “로컬에서만 잠깐” 쓰는 정도로 권장
4) attention 커널 선택: SDPA/FlashAttention로 피크 완화
PyTorch의 SDPA(scaled dot-product attention) 또는 FlashAttention 계열은 메모리 피크를 낮추는 데 도움이 됩니다. 환경에 따라 아래처럼 설정을 강제할 수 있습니다.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.backends.cuda.matmul.allow_tf32 = True
# PyTorch 버전에 따라 SDPA/Flash 설정 방식이 다를 수 있음
또한 transformers 모델에 따라 다음 옵션이 동작합니다.
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
attn_implementation="sdpa", # 지원 모델에서 효과
)
주의:
- 모델/버전 조합에 따라
attn_implementation지원 여부가 다릅니다. - FlashAttention은 별도 설치/컴파일이 필요한 경우가 많아 로컬 환경에서 허들이 있을 수 있습니다.
실제 OOM 진단 루틴: 어디서 터지는지 숫자로 보기
OOM을 “감”으로 잡으면 끝이 없습니다. 아래처럼 최소한의 계측을 넣어, 로드 직후와 생성 전후 VRAM을 비교하세요.
import torch
def vram(tag: str):
if not torch.cuda.is_available():
print(tag, "CUDA not available")
return
torch.cuda.synchronize()
alloc = torch.cuda.memory_allocated() / 1024**2
reserv = torch.cuda.memory_reserved() / 1024**2
max_alloc = torch.cuda.max_memory_allocated() / 1024**2
print(f"[{tag}] alloc={alloc:.1f}MB reserved={reserv:.1f}MB max_alloc={max_alloc:.1f}MB")
# 사용 예
vram("after_load")
with torch.inference_mode():
vram("before_generate")
out = model.generate(**inputs, max_new_tokens=256, use_cache=True)
vram("after_generate")
관찰 포인트:
after_load는 주로 파라미터 메모리after_generate에서 급증하면 KV 캐시/attention 피크reserved가 과도하게 커지는 경우는 allocator 단편화 영향도 의심
이런 식의 “스택 트레이스 기반 원인 추적” 관점은 다른 장애 진단에도 동일하게 적용됩니다. 예를 들어 접근 거부를 IAM 정책 최소화로 추적하는 방식은 AWS IAM AccessDenied 스택추적과 정책 최소화 글에서 유사한 사고방식을 참고할 수 있습니다.
자주 쓰는 처방 조합(로컬 단일 GPU 기준)
아래 조합은 “일단 돌아가게” 만드는 데 효과가 큽니다.
- 4-bit NF4 + BF16 compute
- 입력
max_length를 보수적으로 제한(예: 2048) max_new_tokens를 작게 시작(예: 128~256)- 배치
1로 검증 후 확장 - 가능하면
attn_implementation="sdpa"
예시:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "mistralai/Mistral-7B-Instruct-v0.2" # 예시
bnb = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=bnb,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
).eval()
prompt = "You are a helpful assistant. Summarize the following text..."
inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
with torch.inference_mode():
y = model.generate(
**inputs,
max_new_tokens=192,
do_sample=False,
use_cache=True,
)
print(tok.decode(y[0], skip_special_tokens=True))
“4-bit인데도 OOM”일 때 체크리스트
1) 컨텍스트 윈도우를 과신하지 않았나
모델이 8k/32k 컨텍스트를 지원하더라도, 로컬 GPU에서는 KV 캐시가 먼저 한계에 도달할 수 있습니다. 운영 상한을 더 낮게 잡고, 긴 입력은 전처리로 분해하세요.
2) 스트리밍/동시성으로 KV 캐시가 누적되지 않았나
서빙 코드에서 세션을 유지하거나, 여러 요청을 한 프로세스에서 동시에 생성하면 KV 캐시가 “요청 수만큼” 잡힙니다. 요청 단위로 객체 생명주기를 명확히 하고, 동시성을 제한하세요.
3) 메모리 단편화가 의심되나
긴 시간 실행 후 OOM이 잦아지면 단편화 가능성이 있습니다.
- 프로세스 재시작이 가장 확실한 해법
PYTORCH_CUDA_ALLOC_CONF튜닝도 도움
대규모 로그를 다루다 메모리가 폭주할 때 “원인 분리 후 단계적으로 줄이는” 접근은 Pandas read_csv 메모리 폭주 - 10GB 로그 튜닝에서도 비슷한 패턴으로 설명합니다.
모델을 더 줄이는 추가 옵션들
1) 8-bit로 타협하기
4-bit에서 품질/안정성이 애매하거나 특정 연산에서 문제가 나면 8-bit가 현실적인 절충안입니다.
bnb = BitsAndBytesConfig(load_in_8bit=True)
2) CPU 오프로딩
VRAM이 정말 부족하면 일부 레이어를 CPU로 내릴 수 있지만, 속도 저하가 큽니다. 로컬 개발에서 “기능 확인” 용도로만 권장합니다.
3) 더 작은 모델, 더 짧은 컨텍스트
결국 하드웨어 한계가 명확하면 모델 사이즈를 낮추는 게 총비용 관점에서 이득입니다. 7B에서 3B/1.5B로 내리고, 프롬프트를 최적화하는 쪽이 실사용 만족도가 더 높을 때도 많습니다.
마무리: OOM을 “설정”이 아니라 “구조”로 해결하기
로컬 LLM OOM은 단순히 VRAM이 부족한 문제가 아니라,
- 가중치(파라미터) 메모리
- KV 캐시(컨텍스트 길이, 배치, 동시성)
- attention 커널과 피크 메모리
이 3가지를 분리해서 보는 순간 해결이 빨라집니다.
실행 순서로 정리하면:
- 4-bit NF4로 파라미터 메모리부터 줄이기
- 입력 길이와
max_new_tokens로 KV 캐시 상한을 설계하기 - 배치/동시성 제어로 선형 증가를 막기
- SDPA/FlashAttention 등으로 피크를 낮추기
- 계측 코드로 “어디서 늘었는지”를 수치로 확인하기
성능 튜닝과 장애 진단은 결국 관측과 가설 검증의 반복입니다. 렌더링 폭증을 캐시와 메모이제이션 관점에서 쪼개는 방식이 궁금하다면 Next.js App Router 렌더링 폭증 진단 - RSC 캐시·useMemo 글도 함께 보면, 병목을 구조적으로 분해하는 감각을 다른 분야에도 확장할 수 있습니다.