Published on

Transformers 로컬 LLM OOM - 4bit·KV캐시 최적화

Authors

로컬 GPU에서 transformers로 LLM을 돌리다 보면, "모델은 로드되는데 생성 시작하자마자 OOM" 혹은 "프롬프트 길이가 조금만 늘어도 터짐" 같은 증상을 자주 만납니다. 이 문제는 단순히 VRAM이 부족해서가 아니라, 가중치(Weights)KV 캐시(Key/Value Cache), 그리고 활성값(Activations) 이 어떤 비율로 메모리를 잡아먹는지 이해하면 훨씬 체계적으로 해결할 수 있습니다.

이 글에서는 (1) OOM을 일으키는 메모리 구성 요소를 나눠 보고, (2) bitsandbytes 4bit 로드로 가중치 메모리를 줄이고, (3) KV 캐시를 줄이거나 효율적으로 쓰는 옵션들(정확히는 transformers에서 가능한 범위)을 정리합니다. 속도 최적화가 목적이라면 FlashAttention2도 함께 고려하세요: Transformers 로컬 LLM 속도 2배 - FlashAttention2 적용

로컬 LLM OOM의 3대 원인: 가중치·KV 캐시·활성값

1) 모델 가중치(Weights)

가중치는 모델 로드 순간부터 VRAM을 상시 점유합니다.

  • FP16/BF16: 파라미터당 2바이트
  • FP32: 파라미터당 4바이트
  • 8bit/4bit: 더 작지만, 일부 연산용 버퍼가 추가될 수 있음

예를 들어 7B 모델을 FP16으로 올리면 대략 7e9 * 2 bytes ≒ 14GB 수준이 기본으로 깔립니다(정확한 값은 아키텍처/패딩/버퍼에 따라 달라짐). 그래서 16GB GPU에서 7B가 “간당간당”해지는 이유가 여기 있습니다.

2) KV 캐시(K/V Cache)

KV 캐시는 생성(autoregressive decoding) 중에 과거 토큰의 attention key/value를 저장해 두는 메모리입니다. 프롬프트가 길고, 생성 길이가 길수록 선형으로 증가합니다.

대략적인 감으로는 다음 요소에 비례합니다.

  • batch_size
  • num_layers
  • num_heads 및 head dimension
  • sequence_length(프롬프트 길이 + 생성된 길이)
  • KV dtype(보통 FP16/BF16)

즉, 4bit로 가중치를 줄여도 긴 컨텍스트에서 OOM이 나는 경우가 흔한데, 그때 범인은 KV 캐시인 경우가 많습니다.

3) 활성값(Activations)과 임시 버퍼

추론에서도 연산 중간 텐서/워크스페이스가 생깁니다.

  • torch.compile/커널/어텐션 구현에 따라 워크스페이스가 커질 수 있음
  • 배치가 커지면 활성값도 커짐
  • 샘플링 옵션 자체는 큰 영향을 주지 않지만, use_cache와 시퀀스 길이가 핵심

먼저 진단: 지금 OOM은 어디서 나는가

OOM을 줄이기 전에, “로딩에서 터지는지 vs 생성에서 터지는지”를 나눠야 합니다.

  • 로딩에서 OOM: 가중치/로더 버퍼 문제 가능성이 큼 → 4bit/8bit, device_map, 오프로딩
  • 생성 시작/중간에 OOM: KV 캐시(컨텍스트/생성 길이/배치) 가능성이 큼 → max_new_tokens, 프롬프트 길이, 캐시 전략

간단한 VRAM 체크 유틸을 하나 두면 감이 빨리 옵니다.

import torch

def vram(tag=""):
    if not torch.cuda.is_available():
        print("CUDA not available")
        return
    torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated() / 1024**2
    reserv = torch.cuda.memory_reserved() / 1024**2
    max_alloc = torch.cuda.max_memory_allocated() / 1024**2
    print(f"[{tag}] allocated={alloc:.1f}MiB reserved={reserv:.1f}MiB max_alloc={max_alloc:.1f}MiB")

bitsandbytes 4bit로 가중치 메모리 크게 줄이기

bitsandbytes의 4bit 로딩은 로컬 단일 GPU에서 가장 효과가 큰 “첫 번째 카드”입니다. 핵심은 BitsAndBytesConfig로 4bit 양자화 옵션을 명시하는 것입니다.

권장 설정: NF4 + double quant + BF16 compute

  • bnb_4bit_quant_type="nf4": 4bit 중에서도 품질/안정성이 좋은 편
  • bnb_4bit_use_double_quant=True: 2단 양자화로 추가 절감
  • bnb_4bit_compute_dtype=torch.bfloat16 또는 torch.float16: 연산 dtype
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.eval()

자주 나는 함정

  • GPU가 BF16을 잘 지원하지 않으면 torch.float16로 바꾸는 게 안전합니다.
  • device_map="auto"는 VRAM이 부족하면 일부 레이어를 CPU로 내릴 수 있는데, 이 경우 속도가 크게 느려질 수 있습니다. 그래도 “일단 OOM을 피하는” 목적에는 유효합니다.

4bit인데도 OOM이면: 로딩이 아니라 KV 캐시일 가능성

4bit 로딩은 가중치에만 직접적으로 큰 영향을 줍니다. 프롬프트가 길거나 max_new_tokens가 크면 KV 캐시가 커져서 여전히 OOM이 납니다.

따라서 다음 섹션의 KV 캐시 최적화가 중요합니다.

KV 캐시 최적화: 컨텍스트·생성 길이·배치부터 줄여라

1) 가장 확실한 처방: max_new_tokens와 입력 길이 제한

KV 캐시는 “지금까지의 토큰 수”에 선형 비례합니다. 즉, 아래 두 값을 줄이는 것이 가장 즉효입니다.

  • 입력 프롬프트 토큰 수
  • 생성 토큰 수(max_new_tokens)
import torch

prompt = """너는 로컬 LLM 튜너다. 아래 로그를 보고 OOM 원인을 분석해줘..."""

inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=2048,  # 컨텍스트 상한
).to(model.device)

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=256,  # 생성 길이 상한
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))

운영/제품 관점에서는 입력 길이 제한 + 요약/리트리벌 조합이 가장 현실적입니다. 메모리를 무한정 늘리는 대신, “필요한 정보만 컨텍스트에 넣는” 전략이 결국 비용과 안정성을 동시에 잡습니다. (에이전트/메모리 관리 관점은 AutoGPT 메모리 폭주? 벡터DB TTL로 안정화도 참고할 만합니다.)

2) 배치 크기(batch_size)는 KV 캐시에 직격탄

동시 요청을 배치로 묶으면 처리량은 좋아질 수 있지만, KV 캐시가 배치에 비례해 늘어납니다.

  • OOM이 잦으면 우선 배치를 1로 낮추고 안정 구간을 찾은 뒤, 점진적으로 올리세요.
  • 스트리밍 응답을 한다면 “동시성”을 배치로 해결하기보다 큐잉/레이트리밋을 고려하는 편이 안전합니다.

3) use_cache=False는 메모리를 줄이지만 속도를 크게 희생

KV 캐시를 끄면 토큰을 생성할 때마다 이전 토큰을 매번 다시 계산해야 해서, 일반적으로 속도가 급격히 느려집니다. 하지만 “OOM을 피해야만 하는” 상황에서는 최후의 수단이 될 수 있습니다.

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        use_cache=False,  # KV 캐시 비활성화
    )

실무적으로는 use_cache=False는 디버깅/응급 처치에 가깝고, 보통은 컨텍스트/생성 길이/배치를 조정하는 쪽이 낫습니다.

Transformers에서의 캐시 전략: 무엇이 실제로 도움이 되나

1) 캐시 구현(동적/정적 등)과 메모리

최근 transformers는 캐시 표현을 개선해 왔지만, 로컬 단일 GPU OOM에 대해 “마법처럼” 줄여주는 옵션은 제한적입니다. 그래도 다음은 체크할 가치가 있습니다.

  • 최신 transformers로 업데이트(캐시/어텐션 최적화가 포함되는 경우가 많음)
  • 가능하면 더 효율적인 어텐션 커널 사용(예: FlashAttention2)

FlashAttention2는 주로 속도/워크스페이스에 이점이 있고, 케이스에 따라 메모리에도 도움을 줄 수 있습니다. 적용 방법은 별도 글에 정리했습니다: Transformers 로컬 LLM 속도 2배 - FlashAttention2 적용

2) 긴 컨텍스트가 필요하면: 아키텍처/서빙 스택을 바꿔라

transformers 단일 프로세스로 긴 컨텍스트 + 다중 동시성을 감당하려 하면 KV 캐시가 병목이 됩니다. 이때는 다음 선택지가 더 현실적입니다.

  • vLLM 같은 KV 캐시 관리 최적화 엔진 사용(PagedAttention 등)
  • 서버리스/오토스케일과 결합해 피크를 흡수

서빙 관점에서의 구성은 vLLM+KServe로 LLM 서버리스 배포와 콜드스타트 최소화에서 더 깊게 다룹니다.

OOM을 줄이는 “현실적인” 체크리스트

1) 환경/버전

  • torch, transformers, accelerate, bitsandbytes를 가능한 최신 호환 조합으로
  • NVIDIA 드라이버/CUDA 호환 확인

2) 모델 로딩

  • 1순위: bitsandbytes 4bit(NF4 + double quant)
  • device_map="auto"로 일단 띄우고, 속도가 문제면 더 큰 VRAM 또는 더 작은 모델로

3) 생성 설정

  • 입력 토큰 상한(max_length)을 명시
  • max_new_tokens를 보수적으로
  • 배치 1부터 시작
  • 정말 급하면 use_cache=False

4) PyTorch 메모리 파편화(단편화) 완화

긴 시간 실행하면 “할당은 가능한데 연속 공간이 부족” 같은 형태로 OOM이 날 수 있습니다. 완전한 해결은 어렵지만, 다음이 도움이 되는 경우가 있습니다.

  • 주기적으로 프로세스를 재시작(가장 확실)
  • 실험/서빙을 분리(노트북에서 계속 로드/언로드 반복하지 않기)
  • 필요 시 torch.cuda.empty_cache()는 “캐시 반환”이라 효과가 제한적이지만, 특정 워크로드에서는 숨통이 트일 때가 있음
import torch

def cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

예제: 4bit + 보수적 KV 캐시 운영 템플릿

아래는 로컬에서 “일단 안 터지게” 운영하기 위한 최소 템플릿입니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)
model.eval()

def generate(prompt: str, max_input_tokens: int = 2048, max_new_tokens: int = 256):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            use_cache=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

print(generate("OOM 안 나게 설정 포인트만 요약해줘."))

결론: 4bit는 시작, 승부는 KV 캐시에서 난다

로컬 LLM OOM을 빠르게 줄이는 순서는 보통 다음이 가장 효율적입니다.

  1. bitsandbytes 4bit로 가중치 메모리를 먼저 낮춘다.
  2. 그래도 터지면 대부분 KV 캐시 문제이므로, 입력 길이/생성 길이/배치를 줄인다.
  3. 긴 컨텍스트와 동시성이 필요하면 transformers 단독 최적화에 집착하기보다 vLLM 같은 엔진/서빙 구조로 전환한다.

이 3단계를 기준으로 접근하면, “어제는 되던 게 오늘은 터지는” 식의 시행착오를 크게 줄이면서, 로컬에서도 안정적으로 LLM을 운용할 수 있습니다.