Published on

Transformers 로컬 LLM OOM 해결 - 4bit+KV 캐시

Authors

로컬 GPU(예: 8GB~24GB)에서 transformers로 LLM을 추론하다 보면 가장 먼저 부딪히는 게 CUDA OOM(Out Of Memory)입니다. 특히 "모델은 겨우 올라가는데, 조금만 길게 생성하면 터지는" 패턴이 많습니다. 이건 단순히 가중치(weight)만의 문제가 아니라, KV 캐시(Key/Value cache)활성화(activation) 가 생성 길이에 비례해 커지기 때문입니다.

이 글에서는 OOM을 두 축으로 나눠 해결합니다.

  • 가중치 메모리 줄이기: 4bit 양자화(bitsandbytes)로 VRAM 압축
  • 생성 중 메모리 폭증 줄이기: KV 캐시 사용 방식/길이/배치/어텐션 구현 최적화

아래 내용을 따라 하면 같은 GPU에서도 "올라가기만 하는 모델"에서 "실제로 대화가 되는 모델"로 체감이 달라집니다.

시스템이 진짜로 OOM Kill을 내는 상황(리눅스 커널 OOM, cgroup 제한 등)까지 의심된다면 원인 추적은 별도로 필요합니다. 이 경우는 리눅스 OOM Kill 원인 추적 - dmesg·cgroup·journalctl 도 함께 확인하세요.

OOM의 두 얼굴: 가중치 vs KV 캐시

LLM 추론 VRAM 사용량은 크게 3가지로 나뉩니다.

  1. 모델 가중치(Weights): 로딩 시 거의 고정
  2. KV 캐시: 생성 토큰 수(컨텍스트 길이)에 비례해 증가
  3. 임시 버퍼/활성화: 구현(어텐션 커널, dtype, 배치)에 따라 변동

많은 분들이 "7B는 8GB에서 안 되네"라고 결론내리지만, 실제로는

  • 4bit로 가중치는 들어가도
  • 긴 컨텍스트/긴 출력/큰 배치 때문에 KV 캐시가 터져서 실패

하는 경우가 많습니다.

KV 캐시는 왜 커질까

Self-Attention은 매 토큰마다 과거 토큰의 K/V를 재사용하기 위해 캐시를 쌓습니다. 대략적인 직관은 이렇습니다.

  • 컨텍스트 길이 seq_len 이 커질수록 KV 캐시가 선형 증가
  • 레이어 수, 헤드 수, head_dim이 클수록 증가
  • batch_size 가 커질수록 증가

즉 "긴 대화"를 할수록 VRAM이 계속 늘어납니다. 그래서 max_new_tokenscontext 관리가 OOM 방지의 핵심입니다.

1단계: 4bit 양자화로 가중치 VRAM 줄이기

가장 효과가 큰 1차 처방은 4bit 로딩입니다. bitsandbytes를 쓰면 transformers에서 매우 간단히 적용됩니다.

설치

pip install -U transformers accelerate bitsandbytes

4bit 로딩 예제 (NF4 권장)

아래 예시는 AutoModelForCausalLM을 4bit로 로드합니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Ampere+면 bfloat16 권장
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

포인트

  • bnb_4bit_quant_type 는 보통 nf4가 품질 대비 효율이 좋습니다.
  • bnb_4bit_compute_dtypebfloat16 또는 float16.
  • 4bit는 가중치 메모리를 크게 줄이지만, KV 캐시는 여전히 FP16/BF16 로 쌓이는 경우가 많아(모델/설정에 따라) 긴 생성에서 OOM이 날 수 있습니다.

2단계: KV 캐시로 인한 OOM을 잡는 실전 옵션

(1) 생성 길이 제한: max_new_tokens 가 최우선

무작정 max_new_tokens=2048 같은 설정은 KV 캐시를 급격히 키웁니다. 먼저 안전장치를 걸어두세요.

inputs = tokenizer("Explain KV cache in simple terms.", return_tensors="pt").to(model.device)

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))
  • use_cache=True 는 보통 속도에 유리하지만, 캐시가 커집니다.
  • 다만 use_cache=False 로 끄면 매 토큰마다 과거를 다시 계산해 속도는 크게 느려지고, 활성화/연산 부담이 늘 수 있어 만능 해결책은 아닙니다.

(2) 입력 컨텍스트 관리: "대화 기록"이 VRAM을 먹는다

채팅 앱을 만들 때 흔히 전체 히스토리를 매번 프롬프트에 붙입니다. 이러면 seq_len이 계속 증가해 KV 캐시가 커집니다.

실전에서는 다음 중 하나를 적용합니다.

  • 최근 N턴만 유지(슬라이딩 윈도우)
  • 요약 메모리(중요 정보만 압축)
  • RAG로 필요한 문서만 주입

간단한 슬라이딩 윈도우 예시입니다.

def build_prompt(messages, keep_last=6):
    recent = messages[-keep_last:]
    return "\n".join([f"{m['role']}: {m['content']}" for m in recent])

messages = [
    {"role": "user", "content": "..."},
    # ... 대화 누적
]

prompt = build_prompt(messages, keep_last=6)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

(3) 배치/동시성 줄이기: batch_size=1로 먼저 고정

로컬 서비스로 만들면 "동시 요청"이 OOM 트리거가 됩니다.

  • 배치가 커지면 KV 캐시는 배치에 비례해 늘어납니다.
  • 작은 GPU에서는 동시성을 큐로 제한하는 게 안전합니다.

동시성 문제는 LLM만의 이슈가 아니라 서버 전반의 자원 관리 문제이기도 합니다. 웹 성능과는 결이 다르지만, 병목을 쪼개는 접근은 유사합니다. 프론트 성능 쪽 감각이 필요하면 Chrome INP 개선 - Long Task 분해 실전 가이드 도 참고가 됩니다.

(4) Flash Attention / SDPA 사용: 어텐션 메모리 효율 개선

PyTorch 2 계열에서는 SDPA(Scaled Dot-Product Attention) 경로가 활성화되면 메모리/속도가 좋아지는 경우가 많습니다.

환경에 따라 차이가 있지만, 다음을 점검하세요.

  • PyTorch 버전 업(가능하면 2.1+)
  • transformers 최신
  • GPU가 지원하는 커널 사용

예시(모델 설정으로 강제하는 방식은 모델마다 다를 수 있음):

import torch

torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

이 설정은 "가능하면" flash/mem-efficient 경로를 타게 유도합니다. 실제 적용 여부는 GPU/드라이버/빌드 옵션에 따라 달라집니다.

(5) KV 캐시 자체를 더 줄이는 방법: GQA/MQA 모델 선택

모델 아키텍처에 따라 KV 캐시 크기가 다릅니다.

  • MHA(일반 Multi-Head Attention): 헤드별로 KV를 유지
  • GQA(Grouped Query Attention): Query 헤드는 많지만 KV 헤드는 적어 캐시가 줄어듦
  • MQA(Multi-Query Attention): KV 헤드가 1개 수준이라 캐시가 크게 줄어듦

즉, 같은 파라미터 수라도 GQA/MQA 채택 모델이 긴 컨텍스트에서 더 버팁니다. 로컬 구동이 목적이면 모델 선택 단계에서 이 차이가 크게 납니다.

(6) max_position_embeddings 와 RoPE 스케일링 착각 주의

"컨텍스트 8k 지원" 같은 설정을 적용하면, 모델이 더 긴 입력을 받을 수는 있어도 그만큼 KV 캐시가 더 커져 OOM이 더 빨리 날 수 있습니다. "길게 받게 만드는 것"과 "길게 받아도 메모리가 버티는 것"은 별개입니다.

3단계: OOM 디버깅 루틴(재현 가능한 측정)

OOM은 감으로 잡으면 끝이 없습니다. 아래처럼 매 단계에서 VRAM을 찍어보면 원인이 선명해집니다.

import torch

def vram(msg=""):
    torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"{msg} | allocated={alloc:.1f}MB reserved={reserved:.1f}MB")

vram("after load")

prompt = "Write a detailed explanation about KV cache."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

vram("after tokenize")

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=128, use_cache=True)

vram("after generate")

여기서 패턴이 보통 이렇게 나뉩니다.

  • after load부터 이미 한계면: 가중치가 문제 → 4bit, 더 작은 모델, 오프로딩
  • after generate에서 급증하면: KV 캐시/생성 길이/배치가 문제max_new_tokens, 컨텍스트 축소, 동시성 제한

4단계: 자주 쓰는 안정 조합(권장 프리셋)

8GB GPU 현실 프리셋

  • 7B급: 4bit 필수
  • max_new_tokens: 128~256부터 시작
  • 대화 히스토리: 최근 4~6턴만
  • 동시성: 1(큐잉)
gen_kwargs = dict(
    max_new_tokens=192,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)

12GB~24GB 프리셋

  • 7B~13B: 4bit + BF16 compute
  • max_new_tokens: 256~512
  • 필요하면 입력 길이 제한(예: 토큰 2k~4k)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=2048,
).to(model.device)

truncation은 품질에 영향을 주지만, OOM을 막는 확실한 안전장치입니다.

5단계: 그래도 OOM이면 체크리스트

  • device_map="auto"가 의도치 않게 일부 레이어를 CPU로 보내면서 속도/메모리 패턴이 꼬이지 않았는지
  • 백그라운드에서 다른 프로세스가 VRAM을 쓰고 있지 않은지(nvidia-smi)
  • torch_dtypebnb_4bit_compute_dtype가 불필요하게 float32로 올라가 있지 않은지
  • 생성 파이프라인에서 배치가 커지지 않았는지(여러 요청을 리스트로 한번에 넣는 코드 등)
  • 스트리밍 UI에서 "이전 출력"을 계속 다시 붙여서 프롬프트가 커지지 않았는지

서버 운영 관점에서 "원인은 하나가 아니라 조합"인 경우가 많습니다. 예를 들어 네트워크 레벨 리셋이나 타임아웃이 겉으로 드러나지만, 실제 원인이 리소스 압박인 경우도 있습니다. 인프라 트러블슈팅 감각이 필요하면 EKS ALB Ingress 502 Target reset 원인과 해결 같은 글이 디버깅 접근에 도움이 됩니다.

결론: 4bit는 시작, 승부는 KV 캐시 관리

  • 4bit 양자화는 로컬 LLM OOM 해결의 출발점입니다. 가중치 VRAM을 확 줄여 "로딩"을 가능하게 합니다.
  • 하지만 실제 서비스/대화에서 OOM을 터뜨리는 주범은 KV 캐시인 경우가 많습니다.
  • 따라서 max_new_tokens, 입력 컨텍스트 길이, 동시성(배치), 어텐션 커널(SDPA/Flash)을 함께 최적화해야 안정적으로 돌아갑니다.

다음 단계로는

  • 프롬프트/대화 메모리 전략(요약, RAG)
  • 모델 선택(GQA/MQA 여부)
  • 더 공격적인 캐시 최적화(라이브러리/백엔드 교체)

까지 확장하면, 같은 GPU에서도 "쓸 수 있는" 로컬 LLM 환경을 만들 수 있습니다.