- Published on
Transformers 로컬 LLM OOM·속도 2배 튜닝법
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU에서 Transformers 기반 LLM을 돌리다 보면 두 가지가 항상 발목을 잡습니다. 첫째는 CUDA out of memory(OOM), 둘째는 “돌긴 도는데 너무 느림”입니다. 이 글은 같은 모델을 더 작은 VRAM에서 안정적으로 돌리고, 토큰 생성 속도를 2배 수준까지 끌어올리는 실전 튜닝 순서를 정리합니다.
핵심은 “한 방” 옵션이 아니라 메모리(정적 가중치) + 런타임(activations/KV cache) + 커널(attention/linear) + 실행 그래프(compile) 를 각각 줄이는 것입니다.
아래 레시피는 Transformers + PyTorch + CUDA 기준이며, 모델은 Llama 계열(예: 7B~13B)에서 특히 효과가 큽니다.
1) OOM부터 잡자: VRAM을 먹는 3대 원인
LLM 추론에서 VRAM 사용은 크게 3종류입니다.
- 가중치(Weights): 모델 파라미터 자체. dtype에 따라 선형적으로 변합니다.
- KV 캐시(K/V cache): 생성 토큰이 늘어날수록 계속 쌓이는 메모리. 긴 컨텍스트에서 폭증합니다.
- 중간 activation: 보통
torch.no_grad()/inference_mode()면 작지만, attention 구현/배치/시퀀스 길이에 따라 커질 수 있습니다.
OOM을 “근본적으로” 줄이려면 가중치 dtype/양자화와 KV 캐시 정책을 먼저 손대는 게 가장 효율적입니다.
2) 가장 먼저 적용할 기본 세팅(안정성 + 메모리)
2.1 torch.inference_mode()로 불필요한 그래프 제거
추론에서 autograd 그래프가 남아 있으면 메모리도 느는 데다 속도도 손해입니다.
import torch
@torch.inference_mode()
def generate(model, tokenizer, prompt: str, **gen_kwargs):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, **gen_kwargs)
return tokenizer.decode(out[0], skip_special_tokens=True)
inference_mode()는 no_grad()보다 더 공격적으로 메타데이터를 줄여주는 경우가 많습니다.
2.2 dtype는 bf16 또는 fp16로 고정
Ampere 이상(A100, RTX 30, L4, A10 등)이면 bf16이 정확도/안정성에서 유리한 경우가 많습니다.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16, # 또는 torch.float16
device_map="cuda",
)
model.eval()
여기서 torch_dtype를 안 박아두면, 환경에 따라 fp32로 로드되어 VRAM이 바로 터지는 케이스가 자주 나옵니다.
3) OOM의 진짜 범인: KV 캐시 줄이는 법
긴 컨텍스트에서 OOM이 나는 이유는 대부분 KV 캐시가 토큰 수에 비례해서 커지기 때문입니다. 특히 max_new_tokens를 크게 주거나, 프롬프트가 길면 빠르게 한계에 도달합니다.
3.1 생성 길이 제한은 “성능 최적화”가 아니라 “OOM 방지 장치”
text = generate(
model, tokenizer,
prompt,
max_new_tokens=256,
do_sample=False,
use_cache=True,
)
max_new_tokens를 합리적으로 제한하면 KV 캐시 상한이 생기고, OOM이 재현 불가능해지는 경우가 많습니다.
3.2 배치/동시성은 KV 캐시를 곱하기로 키운다
- 배치 크기
batch_size를 2로 올리면 KV 캐시도 거의 2배로 증가합니다. - 스트리밍 서버에서 동시 요청을 늘리면 “갑자기” OOM이 납니다.
로컬 챗봇/CLI라면 배치 1을 기본으로 두고, 속도는 다른 방법(아래 섹션)으로 끌어올리는 게 안전합니다.
3.3 컨텍스트 윈도우를 무작정 키우지 말 것
RoPE scaling이나 긴 컨텍스트 설정은 편리하지만, 그만큼 KV 캐시가 커집니다. “길게 넣고 요약해서 다시 넣는” 방식이 더 싸게 먹힙니다.
장문 대화 메모리 전략은 아래 글도 참고할 만합니다.
4) 속도 2배 체감 포인트 1: Attention 커널 바꾸기(SDPA/Flash)
LLM 추론에서 가장 비싼 연산 중 하나가 attention입니다. PyTorch 2 계열에서는 SDPA(Scaled Dot-Product Attention) 경로를 타면 커널이 크게 빨라질 수 있습니다.
4.1 SDPA를 강제/유도하기
환경/모델에 따라 자동으로 최적 경로를 타기도 하지만, 명시적으로 설정하는 편이 재현성이 좋습니다.
import torch
# PyTorch 2.x
# flash_sdp / mem_efficient_sdp는 GPU/드라이버에 따라 가용성이 달라집니다.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
- Flash attention 계열이 가능하면 속도 이득이 큽니다.
- 불가능하면 mem-efficient 경로라도 타는 게 낫습니다.
Stable Diffusion 쪽이지만, “attention 커널 선택으로 VRAM/속도 최적화”라는 관점에서 아래 글의 맥락이 유사합니다.
5) 속도 2배 체감 포인트 2: torch.compile()로 그래프 최적화
PyTorch 2의 torch.compile()은 모델/환경에 따라 추론 속도를 꽤 올려주는 옵션입니다. 다만 첫 실행에 컴파일 비용이 들고, 일부 모델/연산에서 호환성 이슈가 있을 수 있습니다.
import torch
model = torch.compile(model, mode="reduce-overhead")
권장 팁:
- 서버라면 “워밍업 요청”을 1~2회 넣어 컴파일을 미리 끝내세요.
mode는reduce-overhead부터 시도하고, 안정적이면max-autotune도 실험해 볼 만합니다.
속도 향상이 애매하면, 대부분은 attention 커널/양자화에서 더 큰 차이가 납니다.
6) OOM과 속도를 동시에 잡는 정답: 8bit/4bit 양자화
VRAM이 부족한 로컬 환경에서 가장 강력한 방법은 가중치 양자화입니다.
- 8bit: 품질 손실이 상대적으로 적고 안정적
- 4bit: VRAM 절감이 크지만 모델/설정에 따라 품질/속도 특성이 달라짐
6.1 bitsandbytes 4bit 로딩 예시
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="cuda",
)
model.eval()
이 조합은 “일단 돌아가게 만들기”에 매우 강합니다. 특히 12GB VRAM에서 7B~13B급을 노릴 때 체감이 큽니다.
6.2 배포 관점의 INT8(PT2E)도 옵션
서비스 배포/재현성 관점에서 PyTorch의 PT2E 기반 INT8 흐름을 고려할 수도 있습니다.
7) “왜 이렇게 느리지?”를 만드는 흔한 실수 7가지
- 프롬프트가 너무 김: 프리필(prefill)이 느려져 첫 토큰까지 시간이 길어집니다.
- 샘플링 옵션 과다:
top_p,temperature자체가 느리진 않지만, 디버그를 어렵게 합니다. 우선do_sample=False로 고정하고 측정하세요. - CPU offload가 섞임:
device_map="auto"에서 일부 레이어가 CPU로 가면 지옥이 열립니다.nvidia-smi로 PCIe 전송이 튀는지 확인하세요. - 토크나이저가 병목: 대량 요청이면 토크나이저가 CPU에서 병목이 됩니다.
use_fast=True를 기본으로. - 불필요한 디코딩 반복: 매 토큰마다 전체 문자열을 decode하면 느립니다. 스트리밍이면 토큰 단위 처리/버퍼링을 고려하세요.
max_new_tokens를 과도하게 크게: 속도 저하뿐 아니라 KV 캐시로 OOM을 유발합니다.- 측정 방법이 잘못됨: 첫 실행은 커널 로딩/컴파일 때문에 느립니다. 워밍업 후 측정하세요.
8) 재현 가능한 벤치마크 코드(토큰/초 측정)
“빨라졌는지”를 감으로 판단하면 튜닝이 끝나지 않습니다. 최소한 토큰/초를 찍어야 합니다.
import time
import torch
from transformers import TextStreamer
@torch.inference_mode()
def bench(model, tokenizer, prompt: str, max_new_tokens: int = 256):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# 워밍업
_ = model.generate(**inputs, max_new_tokens=16, do_sample=False, use_cache=True)
torch.cuda.synchronize()
t0 = time.time()
out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
torch.cuda.synchronize()
t1 = time.time()
# 생성된 토큰 수(대략)
gen_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
tok_per_s = gen_tokens / (t1 - t0)
return {
"gen_tokens": int(gen_tokens),
"seconds": float(t1 - t0),
"tok_per_s": float(tok_per_s),
"text": tokenizer.decode(out[0], skip_special_tokens=True),
}
이 벤치로 다음을 비교해 보세요.
bf16vsfp16- SDPA 설정 on/off
torch.compile()on/off- 4bit 양자화 on/off
보통 “2배”는 한 옵션으로 나오기보다, attention 최적화 + 양자화(또는 compile) 조합에서 달성되는 경우가 많습니다.
9) 실전 추천 적용 순서(체크리스트)
torch_dtype를bf16또는fp16로 고정@torch.inference_mode()적용max_new_tokens와 동시성(배치)을 줄여 OOM 재현을 끊기- SDPA/Flash 경로 유도(가능한 GPU라면 가장 큰 속도 체감)
- VRAM이 부족하면 4bit 또는 8bit 양자화로 전환
- 마지막으로
torch.compile()실험(서빙이면 워밍업 포함)
이 순서대로만 해도 “OOM 때문에 못 쓰던 로컬 LLM”이 실용 영역으로 들어오고, 토큰 생성 속도도 꽤 공격적으로 올라갑니다.
10) 부록: GGUF 모델을 Transformers로 로드할 때 주의
로컬에서 GGUF 파일을 바로 Transformers로 로딩하려다 에러를 겪는 경우가 자주 있습니다. 포맷/로더가 다르기 때문입니다. 아래 글에 흔한 오류와 해결 흐름을 정리해 두었습니다.
마무리
OOM은 대개 “모델이 커서”가 아니라 dtype 미설정, KV 캐시 폭증, 동시성 곱셈 때문에 발생합니다. 속도는 attention 커널(SDPA/Flash) 과 양자화, 그리고 상황에 따라 torch.compile()이 가장 큰 레버리지입니다.
원하는 목표가 “내 GPU에서 안정적으로, 그리고 빠르게”라면 위 체크리스트를 그대로 따라가며 벤치마크로 검증해 보세요. 같은 하드웨어에서도 체감이 확 바뀝니다.