- Published on
Transformers 로컬 LLM CUDA OOM, 4bit·KV캐시 최적화
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU에서 Transformers로 LLM을 띄우다 보면, 모델 가중치를 4bit로 줄였는데도 CUDA out of memory가 나는 경우가 많습니다. 이유는 단순히 “모델 파라미터가 커서”가 아니라, **추론 시점에 추가로 생기는 메모리(특히 KV 캐시, activation, 임시 버퍼)**가 프롬프트 길이와 생성 길이에 따라 급격히 늘기 때문입니다.
이 글에서는 (1) 4bit 로딩으로 가중치 메모리를 줄이고, (2) KV 캐시와 컨텍스트/생성 길이를 제어하며, (3) attention 구현 및 메모리 옵션을 선택해 OOM을 실질적으로 줄이는 방법을 정리합니다. OOM이 GPU가 아니라 노드/컨테이너 레벨에서 터지는 케이스까지 포함해, 어디서 메모리가 새는지 확인하는 체크리스트도 제공합니다.
관련해서, GPU가 아니라 시스템 메모리 압박으로 프로세스가 죽는다면 리눅스 OOM Killer로 프로세스 죽을 때 원인 추적도 함께 보시면 원인 분리가 빨라집니다.
CUDA OOM이 “4bit인데도” 나는 핵심 원인
1) KV 캐시가 토큰 수에 비례해 커진다
오토리그레시브 생성에서 매 토큰마다 각 레이어는 Key/Value를 캐시에 쌓습니다. 이 KV 캐시는 프롬프트 토큰 수 + 생성 토큰 수에 비례해 커집니다.
- 프롬프트가 길수록(대화 히스토리, RAG 컨텍스트) KV 캐시가 커짐
max_new_tokens가 클수록 KV 캐시가 더 커짐- 배치 사이즈가 1이어도, 긴 컨텍스트면 OOM 가능
즉, 가중치를 4bit로 줄여도 KV 캐시가 FP16/BF16으로 유지되면 전체 메모리의 주범이 될 수 있습니다.
2) attention 구현/임시 버퍼가 생각보다 크다
PyTorch SDPA, FlashAttention, xFormers 등 구현에 따라 임시 버퍼 크기와 피크 메모리가 달라집니다. 같은 모델이라도 attention 경로가 바뀌면 OOM이 사라지거나 반대로 더 자주 발생합니다.
3) device_map="auto"는 “GPU 메모리만” 해결하지 않는다
device_map="auto"로 일부 레이어를 CPU로 오프로딩하면 GPU OOM은 줄 수 있지만,
- CPU RAM이 부족해지거나
- PCIe 전송으로 지연이 커지거나
- 오프로딩 버퍼가 커져서 전체 메모리 압박이 생길 수 있습니다.
먼저: 지금 메모리가 어디서 터지는지 빠르게 계측
PyTorch에서 GPU 메모리 피크 확인
아래처럼 생성 전후로 피크 메모리를 찍으면 “가중치 로딩에서 터지는지” vs “생성 중에 터지는지”를 분리할 수 있습니다.
import torch
def report(tag: str):
torch.cuda.synchronize()
alloc = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
peak = torch.cuda.max_memory_allocated() / 1024**3
print(f"[{tag}] alloc={alloc:.2f}GB reserved={reserved:.2f}GB peak={peak:.2f}GB")
torch.cuda.reset_peak_memory_stats()
report("start")
# model load ...
report("after_load")
# generate ...
report("after_generate")
after_load에서 이미 피크가 높으면: 가중치/오프로딩/로딩 설정 문제after_generate에서 피크가 급증하면: KV 캐시, 컨텍스트/생성 길이, attention 경로 문제
nvidia-smi는 “스냅샷”이라 피크를 놓칠 수 있다
생성 중 순간 피크가 OOM을 만들면 nvidia-smi로는 잘 안 잡힙니다. 위의 max_memory_allocated를 같이 보세요.
1단계: 4bit 로딩으로 가중치 메모리부터 확실히 줄이기
Transformers에서 가장 흔한 선택은 bitsandbytes 기반 4bit(NF4) 로딩입니다.
4bit(NF4) 기본 예제
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # 예시
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # Ampere+ 권장, 아니면 torch.float16
bnb_4bit_use_double_quant=True,
)
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model.eval()
체크 포인트
bnb_4bit_compute_dtype는 “연산 dtype”입니다. BF16이 가능한 GPU(Ampere 이상)면 BF16이 안정적인 편입니다.device_map="auto"는 편하지만, VRAM이 타이트하면 레이어 배치가 비효율적일 수 있어 수동으로 조정할 여지가 있습니다.
OOM이 로딩에서 터지면
low_cpu_mem_usage=True를 추가해 CPU 메모리 피크를 줄여보세요.- safetensors 사용 모델이면 로딩 피크가 줄 때가 많습니다.
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
low_cpu_mem_usage=True,
)
2단계: KV 캐시가 OOM의 주범일 때 줄이는 방법
가장 효과적인 레버: 토큰 수를 줄인다
KV 캐시는 토큰 수에 비례하므로, 아래 3가지를 먼저 조입니다.
- 프롬프트 길이 축소: 히스토리 요약, RAG top-k 축소, 불필요한 시스템 프롬프트 제거
max_new_tokens축소: “길게 생성”이 필요하면 스트리밍 + 중간 요약 전략을 고려- 배치/동시성 축소: 동시에 여러 요청을 처리하면 KV 캐시가 요청 수만큼 늘어납니다
max_new_tokens와 입력 길이를 함께 제어하는 예제
import torch
prompt = """You are a helpful assistant. Summarize the following text ..."""
inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=256,
do_sample=False,
use_cache=True,
)
print(tok.decode(out[0], skip_special_tokens=True))
truncation=True와max_length는 “입력 토큰” 상한을 강제합니다.use_cache=True는 보통 속도를 위해 켜지만, 메모리 압박이 심하면 다음 옵션을 검토합니다.
극단 처방: use_cache=False
KV 캐시를 끄면 메모리는 줄지만, 매 토큰마다 과거 토큰을 다시 계산하므로 속도가 크게 느려집니다. “OOM을 피해야만 하는 디버깅/응급 상황”에서만 권장합니다.
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
use_cache=False,
)
KV 캐시 자체를 더 효율적으로: GQA/MQA 모델 선택
같은 파라미터 규모라도 GQA(Grouped Query Attention) 또는 MQA(Multi-Query Attention) 설계를 가진 모델은 KV 헤드 수가 줄어 KV 캐시가 작아지는 경향이 있습니다. 모델 선택 단계에서 “긴 컨텍스트를 로컬에서 돌릴 것”이라면 아키텍처도 고려 대상입니다.
3단계: attention 경로 최적화(FlashAttention/SDPA)
Transformers는 환경과 설정에 따라 attention 구현이 바뀝니다. 일반적으로 FlashAttention 계열이 메모리/속도에 유리한 경우가 많지만, 조합에 따라 오히려 호환성 문제나 OOM 패턴이 달라질 수 있습니다.
attn_implementation 지정 예제
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
attn_implementation="sdpa", # 또는 "flash_attention_2"
)
실전 팁
- 동일 프롬프트/동일
max_new_tokens로sdpa와flash_attention_2를 번갈아 테스트해 피크 메모리를 비교하세요. - 커스텀 커널(FlashAttention2)은 드라이버, CUDA, PyTorch 버전에 민감합니다. 설치가 꼬이면 성능보다 안정성이 우선입니다.
4단계: 오프로딩과 device_map을 “의도적으로” 쓰기
VRAM이 정말 부족하면 CPU 오프로딩이 답일 수 있습니다. 다만 무작정 auto에 맡기기보다, 목표를 정해야 합니다.
- 목표 A: “무조건 OOM 회피” (속도는 느려져도 됨)
- 목표 B: “속도 유지하면서 OOM 회피” (일부만 오프로딩)
max_memory로 상한을 걸어 배치를 유도
from transformers import AutoModelForCausalLM
max_memory = {
0: "20GiB", # GPU 0
"cpu": "48GiB" # 시스템 RAM
}
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
max_memory=max_memory,
)
이 설정은 “GPU에 우겨 넣다가 터지는” 상황을 줄이고, 처음부터 CPU로 일부를 보내도록 유도합니다.
5단계: 생성 파라미터로 피크 메모리 줄이기
빔서치/다중 시퀀스는 메모리를 크게 먹는다
num_beams를 키우면 beam 수만큼 상태를 유지하므로 메모리가 증가합니다.num_return_sequences도 마찬가지입니다.
가능하면 단일 시퀀스, 그리디 또는 낮은 샘플링으로 시작하세요.
out = model.generate(
**inputs,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9,
num_beams=1,
num_return_sequences=1,
)
배치 추론을 한다면 “동시성 제한”이 최우선
서버 형태로 로컬 LLM을 돌릴 때는, 평균 메모리가 아니라 동시 요청에서의 피크로 OOM이 납니다. 큐잉으로 동시성을 제한하거나, 요청별 max_new_tokens 상한을 강제하세요. (서빙 레벨 이슈는 KServe LLM 서빙 503·스케일0 지연 해결법처럼 “플랫폼 병목”으로도 이어질 수 있습니다.)
6단계: 파편화(fragmentation)와 PyTorch allocator 튜닝
메모리가 “총량은 남아 있는데” 할당이 실패하는 경우가 있습니다. 이때는 파편화 가능성이 큽니다.
PYTORCH_CUDA_ALLOC_CONF로 파편화 완화
아래는 흔히 쓰는 옵션입니다.
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,garbage_collection_threshold:0.8"
max_split_size_mb를 너무 작게/크게 잡으면 오히려 역효과가 날 수 있어, 64~256MB 사이에서 실험이 필요합니다.- 장시간 프로세스(서버)에서 특히 효과를 보는 경우가 있습니다.
7단계: “GPU OOM”처럼 보이지만 사실은 CPU/RAM 문제인 경우
4bit 로딩과 오프로딩을 섞으면 CPU RAM 사용량이 예상보다 커질 수 있습니다. 컨테이너/노드에서 메모리 제한이 빡빡하면, GPU OOM 이전에 프로세스가 죽거나 재시작될 수 있습니다.
- 컨테이너 메모리 limit 확인
- dmesg에서 OOM killer 로그 확인
이런 케이스는 리눅스 OOM Killer로 프로세스 죽을 때 원인 추적에 나온 방식대로 “누가 죽였는지”부터 확인하는 게 빠릅니다.
문제 상황별 처방전(우선순위)
A) 모델 로딩 단계에서 OOM
- 4bit(NF4) +
low_cpu_mem_usage=True device_map="auto"+max_memory로 GPU 상한 강제- 그래도 안 되면 더 작은 모델 또는 더 큰 VRAM
B) 짧은 프롬프트도 생성 중 OOM
max_new_tokens축소num_beams=1,num_return_sequences=1attn_implementation을sdpa또는flash_attention_2로 바꿔 피크 비교- 파편화 의심 시
PYTORCH_CUDA_ALLOC_CONF적용
C) 긴 컨텍스트(RAG/대화 히스토리)에서만 OOM
- 입력
max_length강제 + 히스토리 요약 - RAG top-k 축소, chunk 길이 축소
- 필요하면
use_cache=False를 응급으로 사용(속도 저하 감수)
마무리: “4bit는 시작일 뿐, KV 캐시가 본게임”
로컬 LLM에서 CUDA OOM을 줄이려면 가중치(4bit)만 줄이는 것으로는 부족하고, 토큰 수(입력+출력)와 KV 캐시 성장을 함께 관리해야 합니다. 실무적으로는 아래 조합이 가장 재현성 있게 효과를 냅니다.
- 4bit(NF4) 로딩 + BF16 연산 dtype
- 입력
max_length상한 +max_new_tokens상한 num_beams=1로 시작- attention 구현(
sdpa/flash_attention_2)을 바꿔 피크 메모리 비교 - 장시간 서빙이면 allocator 튜닝과 동시성 제한
위 순서대로 적용하면, 같은 GPU에서도 “돌아가긴 도는” 상태를 넘어서, 지속적으로 안정적인 로컬 추론에 가까워질 수 있습니다.