Published on

Transformers 로컬 LLM OOM·속도 해결 - 4비트 실전

Authors

로컬 GPU에서 LLM을 돌리다 보면 거의 항상 두 가지 벽을 만납니다. 첫째는 VRAM OOM, 둘째는 생각보다 느린 생성 속도입니다. 특히 transformers의 기본 로딩 방식으로 7B 이상 모델을 FP16으로 올리면, 로딩은 되더라도 컨텍스트를 조금만 늘려도 메모리가 급격히 증가해 추론이 불안정해집니다.

이 글에서는 4비트 양자화(4bit quantization)를 중심으로 OOM을 줄이면서도 체감 속도를 개선하는 방법을 정리합니다. 단순히 load_in_4bit=True만 켜는 수준이 아니라, 실제로 병목이 되는 KV 캐시, attention 구현, dtype 조합까지 함께 다룹니다.

관련해서 OOM을 더 촘촘히 막는 KV 캐시 전략은 아래 글도 같이 보면 좋습니다.

왜 OOM이 나는가: 가중치보다 KV 캐시가 더 무섭다

LLM 추론에서 VRAM을 먹는 주범은 크게 두 가지입니다.

  1. 모델 가중치(weights)
  2. KV 캐시(Key/Value cache)

가중치는 4비트 양자화로 크게 줄일 수 있습니다. 하지만 KV 캐시는 배치 크기, 시퀀스 길이, 레이어 수, 헤드 수, head_dim, dtype에 비례해 계속 커집니다. 즉, “모델은 4비트로 줄였는데도 OOM”이 나는 대표 원인이 KV 캐시입니다.

대략적인 감을 잡기 위한 간단한 체크리스트는 다음과 같습니다.

  • 컨텍스트 길이(max_new_tokens가 아니라 입력+생성 합) 증가가 OOM을 유발하는가
  • 배치가 2 이상일 때만 터지는가
  • 스트리밍을 끄면 덜 터지는가(토큰을 모아서 한 번에 반환하면 peak가 커질 수 있음)

이 글에서는 우선 “가중치 메모리”를 4비트로 확 줄이고, 그 다음 “KV 캐시로 인한 OOM과 속도”를 튜닝하는 순서로 접근합니다.

준비물: bitsandbytes, accelerate, 최신 transformers

4비트 양자화는 보통 bitsandbytes 기반으로 진행합니다. CUDA가 있는 리눅스 환경에서 가장 안정적이며, 윈도우는 환경에 따라 설치 난이도가 있을 수 있습니다.

pip install -U transformers accelerate bitsandbytes

버전 이슈가 의심되면 아래도 같이 확인하세요.

python -c "import torch, transformers, bitsandbytes as bnb; print(torch.__version__); print(transformers.__version__); print(bnb.__version__)"

4비트 로딩 기본: NF4 + double quant 추천 조합

현업에서 가장 무난한 조합은 다음입니다.

  • 4비트 양자화 타입: nf4
  • 연산 dtype: bfloat16(가능하면) 또는 float16
  • double_quant=True

nf4는 LLM 가중치 분포에 유리한 방식으로 알려져 있고, double quant는 추가 압축으로 VRAM을 더 줄여줍니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # GPU가 bf16 지원하면 권장
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

torch_dtypebnb_4bit_compute_dtype를 같이 쓰는 이유

  • torch_dtype는 일부 모듈 로딩 dtype에 영향을 줄 수 있습니다.
  • bnb_4bit_compute_dtype는 4비트로 저장된 가중치를 실제 연산할 때 어떤 dtype로 계산할지 결정합니다.

여기서 compute dtype을 float16으로 낮추면 VRAM은 비슷하지만 일부 GPU에서 속도나 안정성이 달라질 수 있습니다. Ampere 이상이면 bfloat16이 성능과 안정성에서 유리한 경우가 많습니다.

속도가 느린 이유 1: attention 구현이 기본값이면 손해

LLM 추론 속도는 attention 구현에 크게 좌우됩니다. 최신 transformers에서는 GPU에서 다음 옵션들이 중요합니다.

  • SDPA(Scaled Dot Product Attention)
  • Flash Attention 계열

모델에 따라 attn_implementation을 지정할 수 있습니다.

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",
)

환경에 따라 sdpa가 자동으로 켜지기도 하지만, 명시하면 “내가 지금 어떤 attention 경로를 타는지”를 통제하기 쉬워집니다.

attention 최적화는 Stable Diffusion에서도 VRAM과 속도에 큰 영향을 주는데, 접근 방식이 유사합니다.

속도가 느린 이유 2: KV 캐시를 제대로 쓰지 못하면 매 토큰이 재계산된다

generate는 기본적으로 KV 캐시를 활용하지만, 설정이나 입력 형태에 따라 캐시 효율이 떨어질 수 있습니다.

체크 포인트:

  • use_cache=True가 켜져 있는가
  • 프롬프트를 매 토큰마다 다시 인코딩하거나, 매번 모델을 새로 호출하고 있지 않은가
  • 배치 처리 시 padding이 과도하지 않은가

기본적으로는 아래처럼 명시해두는 편이 안전합니다.

inputs = tokenizer("Explain KV cache in one paragraph.", return_tensors="pt").to(model.device)

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))

컨텍스트가 길수록 KV 캐시가 커져 OOM이 난다

4비트 양자화는 가중치 메모리만 줄입니다. KV 캐시는 보통 FP16 또는 BF16으로 유지되기 때문에, 컨텍스트를 무작정 늘리면 OOM이 다시 발생합니다.

실무적으로는 다음 우선순위로 조정합니다.

  1. max_new_tokens를 줄인다
  2. 입력 프롬프트를 줄인다(특히 시스템 프롬프트, 히스토리)
  3. 배치를 줄인다
  4. 필요하면 KV 캐시 전략(예: sliding window, paged attention 계열)을 고려한다

KV 캐시를 포함한 OOM 방지 전략을 더 깊게 다룬 글은 아래를 참고하세요.

실전: OOM과 속도를 같이 잡는 권장 generate 템플릿

아래 템플릿은 “일단 로컬에서 안정적으로 빠르게”를 목표로 잡은 설정입니다.

  • 4bit NF4
  • compute dtype BF16
  • SDPA
  • inference_mode
  • 불필요한 샘플링 옵션 제거
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    attn_implementation="sdpa",
)
model.eval()

prompt = "You are a helpful assistant. Summarize the benefits of 4-bit quantization for local LLM inference."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

gen_kwargs = dict(
    max_new_tokens=200,
    do_sample=False,
    use_cache=True,
    repetition_penalty=1.05,
)

with torch.inference_mode():
    output = model.generate(**inputs, **gen_kwargs)

text = tokenizer.decode(output[0], skip_special_tokens=True)
print(text)

OOM이 계속 난다면: 4비트 이후에 점검할 것들

4비트로도 터진다면 대개 “KV 캐시” 또는 “피크 메모리” 문제입니다. 아래를 순서대로 확인하세요.

1) 배치와 패딩부터 줄이기

배치가 1이어도, 입력이 길면 OOM이 날 수 있습니다. 배치가 2 이상이면 KV 캐시가 거의 선형으로 증가합니다. 또 padding이 과도하면 실질 시퀀스가 늘어나 메모리를 더 씁니다.

  • 가능하면 batch를 줄이고
  • 여러 요청을 묶을 때는 길이가 비슷한 요청끼리 묶습니다

2) max_new_tokens를 보수적으로 잡기

로컬 서비스에서 흔한 실수는 “상한을 크게 잡아두고 사용자가 알아서 끝내겠지”입니다. 모델이 장문 모드로 들어가면 KV 캐시가 계속 쌓여 OOM이 납니다.

  • 기본값을 128 또는 256 정도로 낮게
  • 필요할 때만 늘리기

3) 메모리 파편화 대응: torch.cuda.empty_cache()는 만능이 아니다

torch.cuda.empty_cache()는 “할당자 캐시 반환”일 뿐, 근본적으로 peak를 낮추지 못합니다. 오히려 호출 타이밍에 따라 성능이 흔들릴 수 있습니다.

대신 다음을 권장합니다.

  • 프로세스를 오래 살리는 서비스라면 워밍업 후 steady state 유지
  • 요청마다 모델을 재로딩하지 않기
  • 프롬프트 템플릿을 고정하고 길이를 통제하기

4) CPU offload는 최후의 수단

device_map="auto"는 VRAM이 부족하면 일부 레이어를 CPU로 내릴 수 있습니다. OOM은 피하지만, 토큰 생성 속도가 급격히 떨어질 수 있습니다.

  • 데스크톱 환경에서 “일단 돌아가게”는 가능
  • 실시간 서비스 목적이면 지양

속도 최적화 팁: 측정부터 제대로 하기

속도 개선은 감으로 하면 실패합니다. 최소한 아래 두 가지 지표는 분리해서 보세요.

  • 첫 토큰까지 시간(TTFT)
  • 토큰 생성 속도(tokens per second)

간단 측정 예시는 다음과 같습니다.

import time
import torch

start = time.time()
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False, use_cache=True)
end = time.time()

num_tokens = output.shape[-1] - inputs["input_ids"].shape[-1]
print("seconds:", end - start)
print("new_tokens:", num_tokens)
print("tok_per_sec:", num_tokens / (end - start))

TTFT가 느리면 보통 다음이 원인입니다.

  • 모델이 cold 상태(초기 커널 컴파일, 캐시 미스)
  • CPU offload 발생
  • 디스크에서 모델 읽기 병목

서비스 환경에서 cold start가 문제라면, GPU 추론 플랫폼에서도 비슷한 패턴으로 튜닝합니다.

정리: 4비트는 시작점이고, 승부는 KV 캐시와 attention에서 난다

  • OOM의 1차 해결책은 4비트 양자화(nf4 + double quant)
  • 속도는 attention 구현(sdpa 등)과 KV 캐시 효율에 크게 좌우
  • 4비트로도 OOM이면 대부분 컨텍스트 길이, 배치, KV 캐시가 원인
  • 측정 지표를 TTFT와 tokens per second로 나눠서 튜닝해야 체감이 좋아짐

로컬 LLM은 “돌아간다”와 “안정적으로 빠르다” 사이의 간극이 큽니다. 4비트로 가중치를 줄인 다음, 컨텍스트 정책과 KV 캐시 전략을 함께 설계하면 OOM과 속도 문제를 동시에 현실적으로 해결할 수 있습니다.