Published on

Transformers 로컬 LLM OOM 해결 - 4bit+KV캐시

Authors

로컬 GPU에서 transformers로 LLM을 올리다 보면, 모델 로딩은 되는데 추론 중에 갑자기 CUDA OOM이 터지거나, 컨텍스트 길이를 늘리는 순간 VRAM이 폭발하는 경우가 많습니다. 이때 많은 분들이 단순히 max_new_tokens를 줄이거나 더 작은 모델로 후퇴하는데, 실제로는 가중치(weight) 메모리KV 캐시(Key/Value cache) 를 분리해서 접근하면 같은 GPU에서도 훨씬 큰 모델과 긴 컨텍스트를 안정적으로 다룰 수 있습니다.

이 글은 bitsandbytes 4bit 양자화로 가중치 메모리를 줄이고, KV 캐시를 정확히 무엇이 VRAM을 잡아먹는지 계산한 뒤 transformers 설정으로 KV 캐시 증가를 통제하는 방법을 정리합니다. 운영 환경에서 추론 서버를 붙일 계획이라면, 리소스 병목을 계측하고 튜닝하는 관점이 중요합니다. 비슷한 성격의 “원인 분해 후 체크리스트로 해결” 접근은 PostgreSQL pgvector RAG 검색 품질 급락 원인과 해결 체크리스트 글에서도 다뤘습니다.

OOM을 두 종류로 쪼개기: 가중치 vs KV 캐시

LLM 추론 시 GPU 메모리는 크게 3덩어리로 나뉩니다.

  1. 가중치(weight) 메모리: 모델 파라미터를 GPU에 올리는 비용
  2. 활성화(activation) 및 임시 버퍼: forward 중 생기는 텐서, 커널 워크스페이스
  3. KV 캐시: 디코딩(autoregressive generation)에서 이전 토큰의 attention key/value를 저장하는 캐시

여기서 로컬 추론에서 체감 OOM의 대부분은 1번과 3번입니다.

  • 모델 로딩 단계에서 OOM이 나면 보통 가중치가 VRAM을 초과한 것입니다.
  • 생성 길이를 늘릴수록, 혹은 컨텍스트가 길수록 OOM이 나면 대개 KV 캐시가 선형으로 증가한 것입니다.

특히 KV 캐시는 “토큰 수에 비례해서 계속 늘어나는 메모리”라서, 4bit 양자화로 모델을 겨우 올려도 컨텍스트를 늘리는 순간 다시 터질 수 있습니다.

bitsandbytes 4bit로 가중치 메모리 줄이기

bitsandbytes 4bit 양자화는 가중치를 fp16 혹은 bf16 대신 4bit로 저장해 VRAM을 크게 절약합니다. 일반적으로 7B급은 8GB에서도 가능해지고, 13B급도 12GB 전후에서 “조건부로” 올라오는 경우가 많습니다.

핵심은 다음 설정 조합입니다.

  • load_in_4bit=True
  • bnb_4bit_quant_type="nf4" (보통 품질 대비 효율이 좋음)
  • bnb_4bit_compute_dtype=torch.bfloat16 또는 torch.float16
  • bnb_4bit_use_double_quant=True (추가 절감)

코드: 4bit 양자화 로딩 기본 템플릿

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.eval()

자주 겪는 함정

  • GPU가 bf16을 잘 못 받는 경우가 있습니다. 그럴 땐 bnb_4bit_compute_dtype=torch.float16로 바꾸세요.
  • device_map="auto"는 레이어를 GPU와 CPU로 분산(offload)할 수 있습니다. 다만 CPU로 많이 넘어가면 속도가 급격히 떨어집니다. “일단 OOM을 피하는 응급처치”로는 유효하지만, 목표 성능이 있다면 GPU에 최대한 올리는 방향으로 조정해야 합니다.

KV 캐시가 왜 이렇게 큰가: 대략적인 메모리 감 잡기

KV 캐시는 레이어마다, 그리고 토큰마다 Key와 Value를 저장합니다. 대략적인 스케일만 잡아도 튜닝 방향이 명확해집니다.

  • KV 캐시 크기 ~ O(num_layers * seq_len * hidden_size)
  • 여기에 Key와 Value 2개가 있으니 대략 2배
  • dtype이 fp16이면 2바이트, fp32면 4바이트

즉, 컨텍스트 길이(seq_len)를 2배로 늘리면 KV 캐시도 거의 2배로 늘어납니다.

여기서 중요한 포인트는 가중치를 4bit로 줄여도 KV 캐시는 보통 fp16/bf16로 유지된다는 점입니다. 그래서 “4bit로 로딩했는데도 긴 컨텍스트에서 OOM”이 흔합니다.

Transformers에서 KV 캐시를 통제하는 실전 옵션

KV 캐시 관련해서 현실적으로 가장 효과가 큰 레버는 다음입니다.

  1. 컨텍스트를 줄인다: 입력 토큰(input_ids) 길이 자체를 제한
  2. 생성 길이를 줄인다: max_new_tokens 제한
  3. 캐시를 끈다: use_cache=False (메모리는 줄지만 속도는 크게 느려짐)
  4. 어텐션 구현을 바꾼다: sdpa 또는 FlashAttention 계열로 임시 버퍼를 줄이고 안정성을 높임

transformers 버전과 모델 아키텍처에 따라 지원 옵션이 조금씩 다르지만, 아래 템플릿은 대부분의 CausalLM에서 유효합니다.

코드: 생성 설정에서 캐시와 길이 제한

import torch

prompt = "요약해줘: KV 캐시가 왜 OOM을 유발하는지"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))
  • OOM이 나는 상황이라면 우선 max_new_tokens를 줄여 “생성 단계에서의 KV 캐시 증가”를 제한하세요.
  • 입력이 길다면 tokenizer 단계에서 truncation을 강제하는 것도 중요합니다.

코드: 입력 길이 강제(트렁케이션)

max_input_tokens = 2048

inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=max_input_tokens,
).to(model.device)

캐시를 꺼서 OOM을 피하는 방법과 트레이드오프

use_cache=False는 KV 캐시를 쌓지 않으므로 메모리는 확 줄어듭니다. 하지만 매 토큰 생성마다 과거 토큰을 다시 계산해야 해서 속도가 크게 느려집니다. “OOM 회피가 최우선”인 디버깅 국면에서는 쓸 수 있지만, 서비스/반복 작업에는 보통 부적합합니다.

코드: use_cache=False

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
        use_cache=False,
    )

attention 구현 선택으로 임시 메모리 줄이기

OOM이 KV 캐시 때문이라고 생각했는데, 실제로는 attention 연산 중 임시 버퍼가 커서 터지는 경우도 있습니다. 특히 긴 시퀀스에서 attention은 계산량과 메모리 모두 부담이 큽니다.

최근 transformers는 PyTorch의 SDPA 경로를 잘 활용합니다. 가능하다면 다음을 시도하세요.

  • attn_implementation="sdpa"
  • GPU와 드라이버가 받쳐주면 FlashAttention 계열(환경에 따라 다름)

코드: sdpa로 로딩

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)

이 설정은 KV 캐시 자체를 줄이진 않지만, 긴 컨텍스트에서의 임시 메모리 사용량과 커널 효율에 영향을 줘서 “경계선에서 OOM이 나던 케이스”를 살리는 데 도움이 됩니다.

VRAM 점유를 확인하는 최소 계측 코드

OOM 해결은 감으로 하면 끝이 없습니다. “로딩 직후”와 “생성 도중”을 나눠서 VRAM을 찍어보면 원인이 명확해집니다.

코드: 단계별 VRAM 확인

import torch

def vram_mb():
    allocated = torch.cuda.memory_allocated() / 1024 / 1024
    reserved = torch.cuda.memory_reserved() / 1024 / 1024
    return allocated, reserved

print("after load", vram_mb())

inputs = tokenizer("hello", return_tensors="pt").to(model.device)

with torch.inference_mode():
    _ = model.generate(**inputs, max_new_tokens=256, use_cache=True)

print("after generate", vram_mb())
  • 로딩 직후에 이미 한계면 4bit/device_map/더 작은 모델이 필요합니다.
  • 로딩은 괜찮은데 생성 후 크게 늘면 KV 캐시가 주범입니다.

OOM을 줄이는 조합 레시피(현실적인 우선순위)

아래 순서대로 적용하면 시행착오가 줄어듭니다.

  1. 4bit 양자화 적용: nf4 + double quant 조합부터
  2. 입력 길이 상한 설정: max_length로 강제 트렁케이션
  3. 생성 길이 제한: max_new_tokens를 보수적으로
  4. sdpa 사용: 경계 OOM 완화
  5. 그래도 안 되면
    • use_cache=False로 기능 확인(속도 포기)
    • CPU offload를 늘리기(device_map 조정)
    • 모델 자체를 더 작은 것으로 변경

운영 관점에서는 “리소스 한도 내에서의 안정성”이 중요합니다. 예를 들어 Cloud Run 같은 환경에서 cold start나 메모리 한도에 부딪히는 문제를 다룰 때도 비슷하게 체크리스트가 필요합니다. 관련해서는 GCP Cloud Run 503·Cold Start 지연 최소화 7가지도 함께 참고할 만합니다.

(보너스) RAG에서 컨텍스트를 무작정 늘리면 망하는 이유

로컬 LLM OOM을 겪는 팀은 종종 RAG를 붙이면서 “검색 결과를 많이 넣으면 답이 좋아지겠지”라고 생각해 컨텍스트를 과도하게 키웁니다. 하지만 컨텍스트를 늘리면

  • KV 캐시가 커져 OOM 위험 증가
  • attention 비용이 커져 지연 증가
  • 프롬프트가 길어져 오히려 정답률이 떨어지는 경우도 발생

즉, 검색 품질을 올려서 적은 토큰으로 더 좋은 근거를 넣는 것이 정공법입니다. 이 관점은 PostgreSQL pgvector RAG 검색 품질 급락 원인과 해결 체크리스트에서 더 깊게 다룹니다.

마무리: 4bit는 시작이고, KV 캐시가 승부처다

정리하면,

  • bitsandbytes 4bit 양자화는 가중치 메모리를 줄여 “모델을 올리는 문제”를 해결합니다.
  • 하지만 긴 컨텍스트/긴 생성에서 터지는 OOM은 대부분 KV 캐시 증가가 원인입니다.
  • 따라서 입력 길이, 생성 길이, 캐시 사용 여부, attention 구현을 조합해 토큰당 메모리 증가를 통제해야 합니다.

로컬에서 안정적으로 돌아가는 설정을 만들면, 이후에는 배포 환경에서 GPU 종류와 VRAM 크기에 맞춰 “안전한 상한선”을 잡는 작업만 남습니다. OOM을 단순히 운으로 피하지 말고, 가중치와 KV 캐시를 분해해서 계측하고 조정하는 습관을 들이면 같은 하드웨어에서도 훨씬 멀리 갈 수 있습니다.