Published on

Transformers 로컬 LLM OOM 해결 - 4bit+KV캐시

Authors

로컬 GPU(혹은 CPU)에서 Transformers로 LLM을 띄우다 보면, 모델 로딩은 되는데 생성 중에 터지거나(특히 긴 프롬프트/긴 출력), 아예 로딩 단계에서 CUDA out of memory가 나는 경우가 많습니다. 이 글에서는 OOM을 가중치(Weights) 문제와 KV 캐시(Key/Value cache) 문제로 분해해서 접근하고, 가장 효과가 큰 조합인 4bit 양자화 + KV 캐시 관리로 해결하는 방법을 코드 중심으로 정리합니다.

실무적으로는 “4bit로 줄였는데도 왜 OOM이지?”가 핵심인데, 그 답은 대부분 KV 캐시가 시퀀스 길이에 비례해 계속 커지기 때문입니다. 즉, 가중치를 줄이는 것만으로는 부족할 수 있습니다.

OOM을 두 덩어리로 나눠 생각하기

1) 가중치 메모리(모델 파라미터)

  • FP16/BF16로 로드하면 파라미터가 그대로 GPU 메모리를 점유합니다.
  • 7B급 모델도 FP16이면 대략 수십 GB까지 갈 수 있어 단일 소비자 GPU에서 바로 한계가 옵니다.
  • 해결책: 8bit/4bit 양자화, 오프로딩, 더 작은 모델, 텐서 병렬 등.

2) KV 캐시 메모리(생성 중 누적)

  • 디코더 기반 LLM은 토큰을 생성할 때 각 레이어별로 K/V를 저장해 재사용합니다.
  • KV 캐시는 대략적으로 배치 크기, 레이어 수, hidden size, 시퀀스 길이에 비례해 증가합니다.
  • “프롬프트가 길다 + max_new_tokens가 크다” 조합이면, 가중치가 4bit여도 KV 캐시로 OOM이 납니다.
  • 해결책: 입력 길이 제한, 출력 토큰 제한, 배치/동시성 제한, KV 캐시 양자화, sliding window, FlashAttention 계열, 캐시 정책 조정.

빠른 진단 체크리스트

아래 질문에 답하면 원인이 거의 좁혀집니다.

  1. 로딩 단계에서 OOM인가?
    • 예: 가중치가 큼. 4bit/8bit, device map, offload를 먼저.
  2. 생성 중간(몇 토큰 생성 후) OOM인가?
    • 예: KV 캐시가 커짐. max_new_tokens, 입력 길이, KV 캐시 최적화가 핵심.
  3. 동시에 여러 요청을 처리하나?
    • 동시성은 KV 캐시를 요청 수만큼 곱합니다.

메모리 문제 접근 방식은 JVM 힙 튜닝과도 유사합니다. “무조건 더 큰 머신”이 아니라, 어떤 객체가 커지는지(여기서는 weights vs KV cache)를 먼저 분리해야 합니다. 메모리 진단 관점은 Spring Boot OutOfMemoryError 덤프 분석·튜닝 7단계 글의 사고방식과도 연결됩니다.

4bit 양자화로 가중치 메모리부터 줄이기

Transformers에서 가장 보편적인 선택은 bitsandbytes 기반 4bit 로딩입니다. 핵심은 BitsAndBytesConfig(load_in_4bit=True)와 적절한 compute dtype을 고르는 것입니다.

설치

pip install -U "transformers>=4.40" accelerate bitsandbytes

4bit 로딩 예시

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # GPU가 bf16 지원하면 권장
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.eval()

옵션 해설(실전 기준)

  • nf4: 4bit 중 품질/안정성 밸런스가 좋아 가장 흔히 씁니다.
  • double_quant: 추가 압축으로 메모리를 더 줄이는 편.
  • bnb_4bit_compute_dtype: 연산 dtype. bf16이 가능하면 fp16보다 수치적으로 안정적인 경우가 많습니다.
  • device_map="auto": GPU에 못 올리는 부분을 CPU로 분산할 수 있지만, CPU 오프로딩은 지연이 커질 수 있습니다.

여기까지 하면 “로딩 OOM”은 상당수 해결됩니다. 하지만 생성 중 OOM은 여전히 남을 수 있습니다.

KV 캐시가 왜 OOM의 주범이 되는가

생성은 대략 다음을 반복합니다.

  1. 입력 토큰을 넣고 forward
  2. 다음 토큰 샘플링
  3. KV 캐시에 K/V를 누적 저장
  4. 다음 스텝에서 KV 캐시를 재사용

즉, 시퀀스 길이가 늘어날수록 캐시가 커집니다. 특히 아래 상황에서 급격히 터집니다.

  • 긴 시스템 프롬프트 + 긴 대화 히스토리
  • max_new_tokens를 크게 잡음
  • 여러 요청을 동시에 처리(서버 형태)

가장 먼저 먹히는 처방: 입력/출력 길이 제한

토크나이즈 단계에서 입력 길이 컷

prompt = """...긴 프롬프트..."""

inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=2048,  # GPU 메모리에 맞춰 조정
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

생성 파라미터로 출력 길이 제한

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))
  • max_new_tokens를 줄이는 것만으로도 KV 캐시 증가를 즉시 제한합니다.
  • 서버 환경에서는 사용자별 상한을 두는 게 안정적입니다.

KV 캐시 최적화 1: 캐시 구현/정책 점검

Transformers는 버전과 모델에 따라 캐시 구현이 다를 수 있습니다. 가능한 경우 다음을 확인합니다.

  • use_cache=True는 속도에는 유리하지만 메모리를 먹습니다.
  • 반대로 use_cache=False는 메모리는 줄지만, 매 토큰마다 전체 시퀀스를 다시 계산해 속도가 크게 느려집니다. 로컬 실험에서만 임시 회피로 고려하세요.
out = model.generate(
    **inputs,
    max_new_tokens=128,
    use_cache=True,  # 기본값이지만 명시
)

만약 특정 작업이 “한 번만 짧게 생성”이면 use_cache=False로 버티는 경우도 있지만, 일반적인 채팅/서빙에서는 권장하지 않습니다.

KV 캐시 최적화 2: FlashAttention 계열로 메모리/속도 개선

가능하다면 FlashAttention2 같은 최적화 커널을 쓰면 attention 메모리 사용량과 속도가 개선되는 경우가 많습니다. 모델/환경에 따라 적용 방식이 다르지만, Transformers에서는 모델 로딩 시 attn_implementation 설정을 제공하는 경우가 있습니다.

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # 지원하는 모델/환경에서
)

주의할 점

  • GPU 아키텍처/드라이버/CUDA/파이토치 버전 제약이 있습니다.
  • 지원하지 않는 모델에서는 에러가 날 수 있으니, 실패 시 기본 구현으로 폴백하세요.

KV 캐시 최적화 3: KV 캐시 양자화(가능한 경우)

“가중치는 4bit인데 KV 캐시는 fp16/bf16”이라서 OOM이 나는 케이스가 많습니다. 최근에는 KV 캐시 자체를 더 낮은 비트로 저장하는 접근(예: 8bit KV cache)이 생태계에서 확산 중입니다.

다만 이 부분은 Transformers 단독 옵션만으로 일괄 해결되기보다,

  • 모델 구현체
  • inference 엔진(vLLM, TensorRT-LLM 등)
  • 커스텀 커널 에 따라 지원 여부가 갈립니다.

Transformers만으로 해결이 어려우면, 서빙을 목표로 할 때는 vLLM 같은 엔진을 검토하는 것이 현실적인 경우가 많습니다(특히 동시 요청이 있을 때).

실전: “4bit + 메모리 안전한 생성” 레시피

아래는 로컬에서 OOM을 덜 내는 방향으로 설정을 모아둔 예시입니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

def generate_safe(prompt: str, max_input_tokens: int = 2048, max_new_tokens: int = 256):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.05,
            use_cache=True,
        )

    return tokenizer.decode(out[0], skip_special_tokens=True)

print(generate_safe("Explain KV cache in LLM inference."))

이 레시피의 의도는 단순합니다.

  • 4bit로 가중치 메모리를 먼저 줄이고
  • 입력/출력 토큰 상한으로 KV 캐시 폭증을 막습니다.

그래도 OOM이면: 자주 놓치는 원인 6가지

1) 프롬프트 템플릿이 생각보다 길다

챗 템플릿(system/user/assistant 태그)이 누적되면 입력 토큰이 급증합니다. 토큰 카운트를 로그로 찍어보세요.

enc = tokenizer(prompt)
print("input_tokens=", len(enc["input_ids"]))

2) 배치 처리 또는 동시성

서버로 띄우면 요청 수만큼 KV 캐시가 쌓입니다. 단일 GPU에서는 동시성 제한이 가장 강력한 안전장치입니다.

3) max_new_tokens가 과도하다

512, 1024 같은 값은 KV 캐시를 크게 만듭니다. “기본 256, 필요 시 늘리기”가 운영에서 안전합니다.

4) torch.compile 혹은 그래프 캡처로 메모리 패턴이 바뀜

성능 튜닝 옵션이 메모리 사용량을 증가시키는 경우가 있습니다. 먼저 안정화 후 튜닝하세요.

5) CPU 오프로딩으로 VRAM은 줄었지만 RAM이 터짐

device_map="auto"가 CPU로 밀어내면 시스템 RAM이 부족해질 수 있습니다.

6) PyTorch 캐시로 “남아있는 것처럼 보이는” VRAM

PyTorch는 메모리를 캐시합니다. 실제 누수와 구분해야 합니다.

import torch
print(torch.cuda.memory_summary())
# 필요 시
torch.cuda.empty_cache()

운영 관점 팁: OOM을 “장애”로 만들지 않기

  • 요청당 토큰 상한을 강제하고, 초과 시 요약/히스토리 축약을 적용합니다.
  • 동시 요청 수를 제한하거나 큐잉합니다.
  • 장애 대응 Runbook을 만듭니다. 인프라에서 흔한 403/노드 이슈처럼, LLM 서빙도 “반복되는 장애 패턴”이 금방 생깁니다. 장애를 단계적으로 좁혀가는 방식은 gRPC 마이크로서비스 503·데드라인 초과 디버깅 같은 글의 접근과도 유사합니다.

결론

Transformers 로컬 LLM OOM은 대부분

  • 가중치가 커서(로딩 단계 OOM)
  • KV 캐시가 커서(생성 중 OOM) 발생합니다.

가장 효과적인 1차 해법은

  • 4bit 양자화로 가중치를 줄이고
  • 입력/출력 토큰 상한으로 KV 캐시 성장을 제한하는 것입니다.

그 다음 단계로는

  • FlashAttention 계열 적용
  • KV 캐시 양자화/서빙 엔진 전환
  • 동시성 제어 를 통해 “로컬에서 안정적으로 길게 생성”하는 구성을 만들 수 있습니다.

이 글의 코드 레시피를 기준으로, 본인 GPU 메모리(예: 8GB, 12GB, 24GB)에 맞춰 max_lengthmax_new_tokens를 먼저 조정해보면 OOM 재현과 해결이 가장 빠르게 됩니다.