- Published on
Transformers 로컬 LLM VRAM OOM, 4bit+FA2로 잡기
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서버급 GPU가 없어도 로컬에서 LLM을 돌릴 수 있는 시대지만, transformers로 모델을 로드하고 몇 번 생성해보면 곧바로 CUDA out of memory를 만나는 경우가 많습니다. 특히 4bit 양자화를 적용했는데도 OOM이 나는 케이스는 대부분 “가중치 메모리”가 아니라 KV 캐시, 어텐션 구현, dtype, 시퀀스 길이, 배치/동시성에서 터집니다.
이 글은 로컬 추론 환경에서 자주 겪는 VRAM OOM을 재현 가능한 방식으로 진단하고, 4bit + FlashAttention2 조합을 중심으로 메모리를 실제로 줄이는 방법을 정리합니다.
- 4bit인데도 OOM이 나는 이유: KV 캐시가 지배적이거나, FA2가 비활성화되어 기본 SDPA/기본 어텐션 경로로 돌아가는 경우
- 해결 전략: 4bit로 가중치 축소 + FA2로 어텐션 워크스페이스/중간 텐서 축소 + 생성 설정으로 KV 캐시 압력 완화
관련해서 4bit 양자화 자체를 더 깊게 보고 싶다면 PyTorch 2.0 PTQ로 LLM 4bit 양자화 실전도 함께 보면 좋습니다.
OOM의 진짜 범인: 가중치가 아니라 KV 캐시인 경우
4bit 양자화는 주로 모델 가중치(weight) 메모리를 줄입니다. 하지만 디코더 계열 LLM의 추론에서 VRAM을 잡아먹는 큰 덩어리는 종종 KV 캐시입니다.
KV 캐시는 대략 다음에 비례합니다.
- 레이어 수
- 히든 차원 및 헤드 수
batch_sizeseq_len(프롬프트 길이) +max_new_tokens(생성 길이)- KV 캐시 dtype(대개
fp16또는bf16)
즉, 7B 모델을 4bit로 낮춰도 입력 프롬프트가 길거나 생성 토큰을 길게 잡으면 KV 캐시가 VRAM을 다 먹고 OOM이 납니다.
추가로, 어텐션 구현이 비효율적인 경로로 가면(예: eager attention) 중간 텐서가 커져 OOM이 더 빨리 발생합니다. 여기서 FlashAttention2가 의미가 있습니다.
4bit + FlashAttention2 조합의 역할 분담
- 4bit(보통 bitsandbytes NF4): 가중치 메모리를 크게 줄임
- FlashAttention2: 어텐션 계산을 타일링/퓨전해서 중간 활성화 및 워크스페이스를 줄이고, 보통 속도도 개선
중요한 포인트는 “4bit만으로는 해결이 안 되는 OOM”이 꽤 많고, 이때 FA2가 켜져 있으면 같은 조건에서 버티는 경우가 많다는 점입니다.
다만 FA2는 설치/버전/아키텍처 조건이 맞지 않으면 자동으로 비활성화되거나 다른 경로로 폴백될 수 있습니다. 그래서 “FA2를 쓴다고 생각했는데 실제로는 안 쓰고 있었다”가 흔한 함정입니다.
실전: Transformers에서 4bit + FlashAttention2로 로드하기
아래 예시는 로컬 GPU에서 AutoModelForCausalLM을 4bit로 로드하고, 어텐션 구현을 FlashAttention2로 지정하는 최소 구성입니다.
주의: 글 본문에서 부등호 문자가 노출되면 MDX에서 빌드 에러가 날 수 있어, 비교 연산 등은 코드 블록 안에서만 사용합니다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "meta-llama/Llama-2-7b-chat-hf" # 예시
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16, # Ampere 이상이면 bf16 권장
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model.eval()
prompt = "한국어로 FlashAttention2가 VRAM을 줄이는 이유를 설명해줘."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=256,
do_sample=False,
use_cache=True,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
체크리스트 1: FA2가 실제로 활성화됐는지 확인
환경이 맞지 않으면 attn_implementation을 줘도 내부적으로 다른 경로로 갈 수 있습니다. 다음을 함께 확인하세요.
flash-attn패키지가 설치되어 있는지- GPU 아키텍처/드라이버/CUDA가 호환되는지
- PyTorch 버전이 너무 낮거나 너무 새로운 조합에서 빌드가 꼬이지 않았는지
간단한 확인 코드는 모델 설정을 찍어보는 것입니다.
print("attn_implementation:", getattr(model.config, "attn_implementation", None))
모델/버전에 따라 config 필드명이 다를 수 있어, 로그에 경고가 없는지도 같이 보세요. FA2가 안 켜져 있으면 OOM 임계점이 확 낮아집니다.
체크리스트 2: dtype 혼합으로 의도치 않은 메모리 증가 방지
4bit를 쓰더라도 연산 dtype이 fp32로 떠버리면 메모리가 늘고 속도도 떨어집니다.
bnb_4bit_compute_dtype를bf16또는fp16로 고정torch_dtype도 동일 계열로 맞추기
특히 소비자 GPU에서 bf16 지원이 애매한 경우가 있으니, 문제가 있으면 torch.float16으로 바꿔 재현해보는 게 좋습니다.
여전히 OOM이면: 생성 설정부터 줄여서 원인 분리
OOM을 “모델이 커서”라고 단정하기 전에, 아래 3가지를 먼저 줄이면 원인 분리가 됩니다.
max_new_tokens를 줄인다- 프롬프트 길이를 줄인다(긴 RAG 컨텍스트는 특히 위험)
- 동시 요청 수(배치/멀티스레드/서빙 워커)를 1로 만든다
예시:
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=64,
do_sample=False,
use_cache=True,
)
이렇게 했을 때는 돌아가는데 max_new_tokens=1024에서만 죽는다면, 가중치가 아니라 KV 캐시 압력이 핵심입니다.
VRAM 사용량을 코드로 계측하기
“진짜로 어디서 늘었는지”를 보려면 생성 전후로 VRAM을 찍어보는 게 가장 빠릅니다.
import torch
def vram_mb():
return {
"allocated": torch.cuda.memory_allocated() / 1024 / 1024,
"reserved": torch.cuda.memory_reserved() / 1024 / 1024,
"max_allocated": torch.cuda.max_memory_allocated() / 1024 / 1024,
}
torch.cuda.reset_peak_memory_stats()
print("before:", vram_mb())
with torch.inference_mode():
out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print("after:", vram_mb())
allocated가 급증하면 실제 텐서가 늘어난 것reserved만 큰데allocated가 낮으면 캐싱/프래그먼테이션 영향
프래그먼테이션이 심하면 같은 총량이어도 OOM이 더 빨리 날 수 있습니다.
FlashAttention2가 OOM을 줄이는 구체 포인트
일반적인 어텐션은 QK^T와 softmax, softmax * V 과정에서 큰 중간 텐서를 만들기 쉽습니다. FlashAttention2는 이를 타일 단위로 처리하고 커널 퓨전을 통해 중간 텐서를 크게 줄이는 방식이라, 다음에서 이점이 큽니다.
- 긴 시퀀스에서 어텐션 중간 텐서 폭발 억제
- 메모리 대역폭 효율 증가로 속도 개선
단, KV 캐시 자체는 “저장해야 하는 과거 정보”라서 FA2만으로 완전히 사라지지 않습니다. 그래서 4bit와 FA2를 같이 쓰되, 생성 설정도 같이 조정해야 합니다.
4bit에서도 자주 터지는 패턴과 처방
패턴 1: 긴 프롬프트(RAG) + 긴 생성
- 증상: 첫 토큰 생성은 되는데, 생성이 길어질수록 VRAM이 계속 증가하다 OOM
- 처방:
max_new_tokens제한- 컨텍스트 윈도우를 줄이거나, 문서 chunk 수를 줄임
- 가능하면 요약 후 질의(2단계)로 프롬프트 길이를 억제
패턴 2: 동시성(서빙)에서 워커가 여러 개
로컬에서도 uvicorn 워커를 늘리거나, 멀티프로세스로 여러 요청을 받으면 GPU 메모리가 워커 수만큼 복제/증가합니다.
- 처방:
- GPU당 프로세스 1개를 기본으로
- 동적 배칭은 “프로세스 수”가 아니라 “요청 합치기”로 해결
서빙 관점의 튜닝은 Triton 쪽으로 확장되는데, 추론 지연과 배칭을 다루는 글로는 Triton 배포 후 지연 폭증 - 동적 배칭·인스턴스 튜닝도 참고할 만합니다.
패턴 3: reserved만 비대해지는 프래그먼테이션
- 증상: 작업을 반복할수록
memory_reserved가 커지고, 어느 순간 OOM - 처방:
- 같은 프로세스에서 모델을 반복 로드/언로드하지 않기
- 실험 루프라면 프로세스 재시작으로 정리
- PyTorch CUDA allocator 설정을 점검(고급)
설치/환경 이슈로 인한 “FA2 미적용”을 줄이는 방법
환경이 꼬이면 FA2 설치가 실패하거나, 설치는 됐는데 런타임에서 폴백될 수 있습니다. 다음처럼 버전을 명시하고, 설치 후 간단한 스모크 테스트를 권장합니다.
pip install -U "transformers" "accelerate" "bitsandbytes"
# flash-attn은 환경 의존성이 강해서, CUDA/torch에 맞는 설치 방법을 프로젝트 가이드에 맞춰 적용
스모크 테스트는 “OOM이 덜 나는지”가 아니라 “경고 없이 해당 커널 경로로 가는지”를 보는 게 핵심입니다.
OOM이 GPU가 아니라 시스템 메모리에서 터질 때
로컬 LLM을 돌리다 보면 VRAM OOM이 아니라 시스템 RAM OOM으로 프로세스가 죽는 경우도 있습니다. 예를 들어
- CPU offload를 과하게 켠 경우
- 데이터셋/문서 로딩이 누적되는 경우
- 여러 프로세스가 동시에 모델을 로드한 경우
이때는 커널 OOM killer 로그로 원인 프로세스를 특정하는 게 빠릅니다. 필요하면 리눅스 OOM Killer 로그로 원인 프로세스 찾기를 참고하세요.
권장 조합(로컬 추론 기준) 요약
- 4bit:
nf4+ double quant + compute dtype를bf16또는fp16 - Transformers:
attn_implementation을flash_attention_2로 명시 - 생성 파라미터:
max_new_tokens를 보수적으로 시작, 프롬프트 길이 관리 - 동시성: GPU당 프로세스 1개, 필요하면 동적 배칭으로 해결
- 계측:
torch.cuda.max_memory_allocated()로 전후 비교
마무리
로컬 LLM VRAM OOM은 “모델이 커서”만이 아니라, 대부분 KV 캐시와 어텐션 구현 경로에서 발생합니다. 4bit는 가중치를 줄여 출발점을 낮추고, FlashAttention2는 어텐션 중간 텐서 부담을 줄여 OOM 임계점을 올립니다. 여기에 생성 길이/프롬프트 길이/동시성을 함께 제어하면, 같은 GPU에서도 체감 안정성이 크게 달라집니다.
다음 단계로는
- 긴 컨텍스트가 필요한 워크로드에서 프롬프트 압축 전략 적용
- 서빙 환경에서 배칭과 워커 모델을 재설계
같은 최적화를 이어가면, “로컬에서도 실사용 가능한” 수준까지 충분히 끌어올릴 수 있습니다.