Published on

Transformers 로컬 LLM OOM - 8bit·KV 캐시 최적화

Authors

로컬 GPU에서 transformers로 LLM을 띄우다 보면, 모델 로딩은 되는데 생성 단계에서 갑자기 OOM이 나거나, 반대로 로딩 단계에서 바로 터지는 일이 흔합니다. 특히 “--max_new_tokens를 조금만 올려도 죽는다” 같은 증상은 대개 **KV 캐시(어텐션 캐시)**가 VRAM을 선형으로 잡아먹기 때문입니다.

이 글은 다음 두 축으로 OOM을 줄이는 방법을 정리합니다.

  • 가중치 메모리 줄이기: 8bit(또는 4bit) 로딩, dtype 전략
  • 런타임 메모리 줄이기: KV 캐시/컨텍스트 길이, 배치/스트리밍, 캐시 구현 선택

실무에서 “일단 되게 만들기”가 아니라, 어디서 VRAM이 새는지를 감 잡을 수 있게 수치와 체크리스트 중심으로 설명하겠습니다.

OOM이 나는 지점부터 구분하자

OOM은 크게 3군데에서 납니다.

  1. 모델 로딩 시점: 가중치가 VRAM에 못 올라감
  2. 프리필(prefill) 시점: 긴 프롬프트를 한 번에 넣는 단계에서 어텐션/활성화가 폭증
  3. 디코딩(decode) 시점: 토큰을 생성할수록 KV 캐시가 누적되어 터짐

증상만 보면 비슷하지만 처방이 다릅니다.

  • 로딩 OOM이면: 8bit/4bit, device_map, CPU offload
  • prefill OOM이면: 입력 길이(컨텍스트) 줄이기, batch_size 줄이기
  • decode OOM이면: max_new_tokens/컨텍스트/use_cache/KV 캐시 타입 최적화

먼저 현재 VRAM 사용량을 “눈으로” 확인하기

파이토치는 캐시 allocator 때문에 “사용 중”과 “예약됨”이 달라서 헷갈립니다. 아래처럼 찍어보면 원인을 분리하기 쉽습니다.

import torch

def vram(msg: str = ""):
    if not torch.cuda.is_available():
        print("CUDA not available")
        return
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"[{msg}] allocated={allocated:.1f}MB reserved={reserved:.1f}MB")

vram("start")
# 모델 로딩
vram("after load")
# 프롬프트 토큰화/프리필
vram("after prefill")
# 생성 후
vram("after generate")

reserved가 계속 커지는데 allocated가 덜 늘면, 조각화(fragmentation)나 캐시 정책 이슈일 수 있습니다. 반대로 allocated가 토큰 생성에 따라 선형으로 증가하면 KV 캐시가 범인일 확률이 높습니다.

8bit 로딩: 가장 쉬운 “가중치 메모리” 절감

bitsandbytes로 8bit 로딩하기

bitsandbytes 8bit는 가중치를 INT8로 들고 있으면서 연산은 혼합 정밀도로 처리합니다. 16bit 대비 가중치 VRAM이 크게 줄어 “로딩 OOM”을 먼저 잡는 데 효과적입니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

model.eval()

포인트:

  • device_map="auto"는 레이어를 GPU/CPU로 자동 분산합니다. VRAM이 애매할 때 생존 확률이 올라갑니다.
  • torch_dtype는 모델/커널 선택에 영향을 줍니다. GPU가 bfloat16에 강하면 torch.bfloat16도 고려하세요.

8bit로도 부족하면: 4bit(QLoRA 스타일)까지

8bit로 로딩은 되는데 생성에서 터진다면, 사실 가중치보다 KV 캐시가 문제일 수 있습니다. 그래도 VRAM이 빡빡하면 4bit도 옵션입니다.

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

주의:

  • 4bit는 속도/품질/호환성 트레이드오프가 있습니다.
  • “로딩은 되는데 느려졌다”는 흔한 부작용입니다.

진짜 주범: KV 캐시가 왜 OOM을 만드는가

디코딩 단계에서 모델은 매 토큰마다 과거 토큰의 K/V를 저장해 두고(캐시), 다음 토큰을 빠르게 계산합니다. 이 캐시는 대략 다음에 비례합니다.

  • 레이어 수
  • 헤드 수/헤드 차원
  • 배치 크기
  • 현재까지의 시퀀스 길이(프롬프트 길이 + 생성된 토큰)
  • dtype(대개 fp16 또는 bf16)

즉, max_new_tokens를 올리면 캐시가 계속 누적되어 VRAM이 선형 증가합니다.

가장 확실한 방법: 컨텍스트와 생성 길이를 줄이기

  • 입력 프롬프트 길이를 줄이기: 불필요한 로그/코드/JSON을 그대로 넣지 않기
  • max_new_tokens를 제한하기
  • 여러 요청을 묶는 배치 생성을 피하기(배치가 커지면 KV 캐시도 배수로 커짐)

예시:

inputs = tok(
    "질문: ...\n답변:",
    return_tensors="pt",
    truncation=True,
    max_length=2048,
).to(model.device)

out = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=False,
)

여기서 max_length는 “입력 토큰” 상한이고, max_new_tokens는 “생성 토큰” 상한입니다. 둘 다 KV 캐시에 영향을 줍니다.

use_cache=False는 최후의 수단

KV 캐시를 끄면 메모리는 줄지만, 디코딩이 매번 과거 전체를 다시 보게 되어 속도가 크게 떨어집니다. 그래도 “일단 OOM만 피하자”면 선택지입니다.

out = model.generate(
    **inputs,
    max_new_tokens=128,
    use_cache=False,
)

실전에서는 보통 use_cache=True를 유지하고, 아래의 “캐시 구현 최적화”로 해결하는 편이 낫습니다.

Transformers의 KV 캐시 구현을 바꿔서 메모리/속도 최적화

Transformers는 버전이 올라가며 캐시 구현이 다양해졌고, 일부는 메모리 효율이 더 좋습니다. 환경에 따라 차이가 크니 “되는 조합”을 찾는 게 중요합니다.

SDPA(Scaled Dot-Product Attention) 활성화

PyTorch 2.x의 SDPA는 커널 최적화로 메모리와 속도를 개선할 수 있습니다.

import torch
from transformers import AutoModelForCausalLM

torch.backends.cuda.matmul.allow_tf32 = True

a = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)

모델에 따라 attn_implementation 지원 여부가 다릅니다. 지원하지 않으면 경고/에러가 납니다.

FlashAttention 계열 사용(가능한 경우)

FlashAttention은 어텐션 계산을 더 메모리 효율적으로 합니다. 다만 설치/호환성이 까다롭고, GPU/드라이버/파이토치 버전에 따라 편차가 큽니다.

Transformers가 지원하는 경우:

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

여기서도 핵심은 “KV 캐시 자체 크기”를 0으로 만들지는 못하지만, 프리필/디코딩에서의 메모리 피크를 낮춰 OOM 경계에서 살려주는 경우가 많습니다.

CPU 오프로딩과 device_map으로 “어떻게든 띄우기”

VRAM이 부족하면 일부 레이어를 CPU로 넘기는 전략이 있습니다. 속도는 느려지지만, 개발/테스트 환경에서는 유용합니다.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
)

device_map="auto"만으로 부족하면 accelerate의 오프로딩 설정(예: 오프로드 폴더)을 추가로 고려할 수 있습니다. 다만 디스크 I/O가 병목이 될 수 있어, NVMe가 아니면 체감이 큽니다.

자주 놓치는 OOM 원인 6가지 체크리스트

1) 배치/동시성: “요청 2개”가 곧 KV 캐시 2배

서빙 코드에서 동시 요청을 처리하면, 각 요청의 KV 캐시가 동시에 쌓입니다. 로컬 테스트에서는 한 번에 한 요청만 처리하도록 제한하는 게 안전합니다.

동시성/스트리밍으로 인한 메모리 누수/중복 토큰 문제는 별개로도 자주 발생합니다. 스트리밍 파이프라인을 쓰고 있다면 LangChain 스트리밍 중복토큰·메모리누수 9분 해결도 함께 확인해 두면 좋습니다.

2) max_new_tokens를 무작정 크게 잡음

“최대 2048 토큰 생성” 같은 설정은 KV 캐시를 폭발시키기 쉽습니다. UI에서 길이를 열어두더라도 서버 내부 상한을 두고, 길어지면 요약/청킹으로 우회하는 편이 낫습니다.

대용량 입력을 청크로 나누는 패턴은 LLM API에서도 동일합니다. 로컬 모델이라도 입력을 잘게 나누는 사고방식이 도움이 되며, 관련해서는 OpenAI Responses API 413 에러 업로드 용량 제한과 청크 전략을 참고할 수 있습니다.

3) 프롬프트에 불필요한 히스토리를 계속 누적

대화 히스토리를 매번 전부 넣으면 컨텍스트가 금방 4k, 8k를 넘어가고, prefill에서 피크가 커집니다. “최근 N턴만 유지” 또는 “요약 메모리”를 도입하세요.

4) torch.no_grad() / inference_mode() 누락

학습 그래프가 남으면 메모리가 급증합니다.

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=128)

5) 토크나이저 패딩/정렬로 시퀀스가 불필요하게 길어짐

배치 토크나이즈 시 padding=True는 가장 긴 시퀀스에 맞춰 패딩되어 KV 캐시 비용이 커질 수 있습니다. 로컬 단일 요청이면 패딩을 최소화하세요.

6) 메모리 해제가 안 된 것처럼 보임

파이토치는 메모리 풀을 유지합니다. “사용량이 안 내려간다”가 곧 누수는 아닙니다. 다만 요청 단위로 객체를 오래 잡고 있으면 실제 누수가 됩니다.

서빙 코드에서 리소스 생명주기를 명확히 하고 싶다면, 파이썬 컨텍스트 매니저를 활용해 누수를 줄이는 패턴이 도움이 됩니다. 관련해서는 Python 3.11+ asynccontextmanager로 리소스 누수 막기도 참고할 만합니다.

실전 예제: 8bit + SDPA + 길이 제한으로 OOM 방어

아래는 “대부분의 단일 GPU 로컬 환경”에서 무난히 출발하기 좋은 조합입니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

torch.backends.cuda.matmul.allow_tf32 = True

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
    attn_implementation="sdpa",
)
model.eval()

def generate(prompt: str, max_input: int = 2048, max_new: int = 256):
    inputs = tok(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input,
    ).to(model.device)

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new,
            do_sample=False,
            use_cache=True,
        )

    return tok.decode(out[0], skip_special_tokens=True)

print(generate("Explain KV cache in simple terms."))

이 상태에서 OOM이 나면, 우선순위는 보통 다음 순서가 효율적입니다.

  1. max_inputmax_new를 줄여서 “decode OOM”인지 확인
  2. 여전히 로딩이 불안정하면 4bit로 전환
  3. 동시 요청/배치를 제거
  4. FlashAttention 지원 환경이면 attn_implementation 교체 테스트

결론: OOM은 ‘모델 크기’보다 ‘길이’에서 터진다

로컬 LLM OOM은 흔히 “VRAM이 작아서”라고 결론 내리지만, 실제로는 컨텍스트 길이와 생성 길이, 그리고 동시성이 KV 캐시를 폭발시키는 경우가 많습니다.

  • 로딩이 안 되면: 8bit/4bit + device_map으로 가중치부터 줄이기
  • 생성 중 죽으면: KV 캐시 관점에서 입력/출력 길이와 동시성을 먼저 줄이기
  • 경계에서 불안정하면: SDPA/FlashAttention 등 커널 최적화로 피크를 낮추기

이 3단계를 체계적으로 적용하면, 같은 GPU에서도 “가끔 OOM”에서 “안정적으로 서빙”으로 넘어갈 수 있습니다.