Published on

Transformers 로컬 LLM OOM 방지 - 4bit+KV 캐시

Authors

로컬 GPU에서 transformers로 LLM을 띄우다 보면, 모델 로딩은 되는데 생성 중에 갑자기 CUDA OOM이 터지거나, 배치/컨텍스트 길이를 조금만 올려도 바로 메모리가 바닥나는 경험을 자주 합니다. 이 문제는 단순히 “VRAM이 부족해서”로 끝나지 않습니다. 가중치(Weights), KV 캐시, 활성화(Activations), 그리고 런타임 메모리 단편화까지 여러 요인이 겹쳐서 발생합니다.

이 글에서는 OOM을 줄이는 가장 실전적인 조합인 4bit 양자화(QLoRA 계열, bitsandbytes) + KV 캐시 최적화를 중심으로, 로컬 추론에서 안정적으로 돌아가게 만드는 체크리스트와 코드 예제를 정리합니다.

참고: 운영 환경에서 OOM은 종종 재시작 루프를 유발합니다. 쿠버네티스에서 이런 현상이 CrashLoopBackOff로 보일 때는 메모리 원인과 프로브 설정이 함께 얽히기도 하니, 필요하면 K8s CrashLoopBackOff - liveness probe 오탐 해결도 같이 확인해보세요.

OOM의 진짜 범인: 가중치보다 KV 캐시가 더 무섭다

대부분 “모델 파라미터가 크니까 OOM”이라고 생각하지만, 긴 컨텍스트 + 긴 생성을 하면 KV 캐시가 VRAM을 빠르게 잠식합니다.

  • 가중치 메모리: 모델 로딩 시 거의 고정
  • KV 캐시 메모리: batch_size * (prompt_len + gen_len)에 비례하여 계속 증가

대략적인 KV 캐시 크기 감각만 잡아도 튜닝 방향이 명확해집니다.

KV 캐시 메모리 대략 계산식

모델마다 세부는 다르지만, 흔한 디코더 전용 Transformer에서 토큰당 KV 캐시 비용은 대략 다음과 같이 생각할 수 있습니다.

  • 토큰당 저장 요소 수: 2 * num_layers * hidden_size
    • 2는 K와 V
  • 바이트 수: dtype_bytes (예: fp16은 2 bytes)

즉,

  • 전체 KV 캐시 바이트 ≈ batch * seq_len * 2 * layers * hidden * dtype_bytes

예를 들어 layers=32, hidden=4096, fp16(2 bytes), batch=1, seq_len=4096이면:

  • 1 * 4096 * 2 * 32 * 4096 * 2 bytes4GB 수준까지도 커질 수 있습니다(모델 구조에 따라 달라짐).

결론은 간단합니다.

  • 가중치를 4bit로 줄여도
  • KV 캐시가 fp16로 크게 잡히면
  • 생성 중 OOM이 계속 날 수 있습니다.

따라서 “4bit 로딩”과 “KV 캐시 전략”은 반드시 같이 봐야 합니다.

1) 4bit 양자화로 가중치 메모리부터 줄이기

가장 손쉬운 첫 단계는 bitsandbytes 기반 4bit 로딩입니다. 핵심은 BitsAndBytesConfig 설정이며, 로컬 GPU에서는 보통 nf4 + double quant 조합이 안정적입니다.

설치

pip install -U transformers accelerate bitsandbytes

4bit 로딩 코드 (Transformers)

아래 코드는 4bit로 모델을 로드하고, 추론에 필요한 기본 설정을 포함합니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Ampere 이상이면 bf16 추천
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

자주 하는 실수

  • torch_dtypefloat16로 강제했는데 GPU가 bf16에 더 최적화된 경우
  • device_map을 수동으로 잘못 지정해서 GPU와 CPU 오프로딩이 비효율적으로 섞이는 경우
  • 모델 로딩은 되는데 생성에서 OOM이 나는 경우(대부분 KV 캐시 이슈)

4bit는 “로딩을 가능하게” 만들지만, “긴 대화/긴 생성이 안정적”인지는 별개입니다.

2) KV 캐시 최적화: OOM 방지의 핵심

KV 캐시는 크게 두 가지 방향으로 줄일 수 있습니다.

  1. KV 캐시 자체를 더 작은 dtype/포맷으로 저장
  2. KV 캐시가 커지지 않게 시퀀스 길이와 생성 길이를 제어

2-1) use_cache와 생성 길이 제한

가장 먼저 확인할 것은 generate 설정입니다.

  • max_new_tokens를 과하게 주면 KV 캐시가 계속 커집니다.
  • max_length는 입력 길이까지 포함하므로 혼동하기 쉽습니다.
prompt = "요약: Transformers에서 OOM을 줄이는 방법은"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))
  • use_cache=True는 보통 속도를 크게 올리지만 KV 캐시 메모리를 사용합니다.
  • 반대로 use_cache=False는 메모리는 줄 수 있지만 속도가 크게 느려질 수 있습니다.

즉, OOM을 피하려고 use_cache=False로 도망가기보다는, 아래의 “캐시 포맷 최적화”를 먼저 고려하는 편이 실전적입니다.

2-2) cache_implementation로 캐시 구현 바꾸기

최근 transformers는 캐시 구현을 선택할 수 있는 옵션을 제공합니다(버전에 따라 지원 범위가 다릅니다). 환경에 따라 static이나 sliding_window 같은 방식이 도움이 될 수 있습니다.

다만 이 옵션은 모델/버전 호환성이 있으니, 아래처럼 “가능하면 적용”하는 식으로 접근하는 게 안전합니다.

gen_kwargs = dict(
    max_new_tokens=256,
    do_sample=False,
    use_cache=True,
)

# 일부 버전에서 동작: 캐시 구현을 명시
# 동작하지 않으면 해당 인자를 제거하세요.
try:
    gen_kwargs["cache_implementation"] = "static"
except Exception:
    pass

with torch.inference_mode():
    out = model.generate(**inputs, **gen_kwargs)

핵심은 “캐시가 어떤 형태로 잡히는지”를 통제하는 것입니다. 특히 긴 컨텍스트에서 캐시 할당이 커질 때, 구현에 따라 메모리 피크가 달라질 수 있습니다.

2-3) 긴 대화는 sliding window 전략을 고려

로컬 챗봇을 만들 때 흔히 하는 실수는 대화 전체를 매번 프롬프트에 붙이는 방식입니다. 이렇게 하면 prompt_len이 계속 증가하고, KV 캐시도 매 턴 커져서 결국 OOM이 납니다.

대안은 두 가지입니다.

  • 요약을 주기적으로 만들어 컨텍스트를 압축
  • 최근 N 토큰만 유지하는 슬라이딩 윈도우

슬라이딩 윈도우는 모델이 기본적으로 지원하는 경우도 있지만, 애플리케이션 레벨에서 “최근 메시지만 유지”해도 효과가 큽니다.

def build_context(messages, max_chars=6000):
    # 아주 단순한 예시: 최근 메시지부터 누적해서 길이 제한
    ctx = ""
    for m in reversed(messages):
        candidate = f"{m['role']}: {m['content']}\n" + ctx
        if len(candidate) > max_chars:
            break
        ctx = candidate
    return ctx

messages = [
    {"role": "user", "content": "로컬 LLM이 자꾸 OOM 나요"},
    {"role": "assistant", "content": "대부분 KV 캐시가 원인입니다"},
]

prompt = build_context(messages) + "assistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

문자 수 제한은 토큰 제한과 정확히 일치하진 않지만, “컨텍스트 무한 증가”를 막는 1차 안전장치로는 유효합니다.

3) 메모리 피크를 줄이는 로딩/런타임 설정

4bit + KV 캐시 외에도, 로컬 환경에서 OOM을 줄이는 실전 옵션들이 있습니다.

3-1) torch.inference_mode()eval()은 기본

추론인데도 그래프가 남거나 드롭아웃 등이 켜져 있으면 불필요한 메모리 사용이 생깁니다.

  • model.eval()
  • with torch.inference_mode():

이 두 개는 거의 필수입니다.

3-2) device_map="auto" + CPU 오프로딩은 양날의 검

accelerate의 자동 배치는 VRAM이 부족할 때 CPU로 일부를 내립니다. 로딩은 성공하지만,

  • 속도가 크게 느려지고
  • CPU RAM과 PCIe 전송이 병목이 되며
  • 특정 순간 전송 버퍼로 피크가 튈 수 있습니다.

가능하면 “모델은 GPU에 다 올리되(4bit로)”, KV 캐시와 시퀀스를 관리하는 쪽이 체감이 좋습니다.

3-3) PyTorch 메모리 단편화 완화

길게 실행하는 서버에서는 메모리 단편화로 인해 “남은 VRAM이 있는데도” OOM이 날 수 있습니다. 다음 환경 변수를 고려할 수 있습니다.

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

또는 코드에서 캐시를 주기적으로 정리하는 방법도 있습니다(과하면 성능 저하).

import torch

def maybe_cleanup(step, every=50):
    if step % every == 0:
        torch.cuda.empty_cache()

4) OOM 디버깅: 어디서 터지는지 수치로 확인하기

OOM 최적화는 감으로 하면 시간이 오래 걸립니다. 최소한 아래 두 가지는 찍어보는 게 좋습니다.

  • 현재 할당/예약 VRAM
  • 프롬프트 토큰 수, 생성 토큰 수
import torch

def vram_report(tag=""):
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / (1024**3)
    reserved = torch.cuda.memory_reserved() / (1024**3)
    print(f"[{tag}] allocated={allocated:.2f}GB reserved={reserved:.2f}GB")

vram_report("after_load")

prompt = "설명: KV 캐시가 OOM을 유발하는 이유"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
print("prompt_tokens=", inputs["input_ids"].shape[-1])

with torch.inference_mode():
    vram_report("before_generate")
    out = model.generate(**inputs, max_new_tokens=512, use_cache=True)
    vram_report("after_generate")

여기서 max_new_tokens를 128, 256, 512로 바꿔가며 증가폭을 보면, “OOM이 가중치 때문인지 KV 캐시 때문인지”가 즉시 드러납니다.

5) 추천 조합: 로컬 LLM을 안정적으로 굴리는 프리셋

아래는 “일단 잘 돌아가게” 만드는 현실적인 프리셋입니다.

  • 가중치: 4bit nf4 + double quant
  • compute dtype: 가능하면 bf16
  • 생성: max_new_tokens를 보수적으로 시작(예: 256)
  • 대화: 슬라이딩 윈도우 또는 요약으로 프롬프트 무한 증가 방지
  • 실행: inference_mode, eval

예시 코드(위 요소를 한 번에 묶음):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
model.eval()

def chat_once(user_text: str, max_new_tokens: int = 256):
    prompt = f"user: {user_text}\nassistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            use_cache=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

print(chat_once("OOM을 줄이는 핵심만 요약해줘"))

마무리: 4bit는 시작이고, 승부는 KV 캐시에서 난다

로컬에서 LLM OOM을 잡는 가장 빠른 길은 다음 순서로 접근하는 것입니다.

  1. 가중치부터 4bit로 줄여 로딩 여유를 확보
  2. KV 캐시가 커지는 조건을 통제: max_new_tokens, 대화 컨텍스트 증가, 캐시 구현
  3. 실측 기반으로 튜닝: VRAM 리포트 + 토큰 길이

이 흐름대로 하면 “모델은 뜨는데 생성에서 죽는” 상태에서 벗어나, 같은 GPU에서도 더 긴 컨텍스트와 더 안정적인 서비스를 만들 수 있습니다.

추가로, 로컬 추론을 서비스로 감싸 배포/자동화까지 가는 단계에서는 CI 캐시 전략도 중요해집니다. 빌드/배포가 느려서 반복 튜닝이 막힐 때는 GitHub Actions 캐시 무효화로 빌드 느림 해결 같은 글이 의외로 도움이 됩니다.