Published on

Transformers 로컬 LLM 추론 2배 - flash-attn·KV캐시

Authors

로컬에서 Transformers로 LLM을 돌리다 보면, 같은 모델인데도 추론 속도가 들쭉날쭉하거나 토큰 생성이 생각보다 느린 경우가 많습니다. 이때 가장 큰 병목은 대개 attention 계산과, 디코딩 단계에서 반복되는 KV 캐시 접근 패턴입니다.

이 글에서는 추론 체감 2배를 목표로, (1) flash-attn 계열 커널로 attention을 가속하고 (2) KV 캐시를 올바르게 켜고 튜닝하는 방법을 중심으로 정리합니다. 끝부분에는 실제로 속도가 올랐는지 확인하는 측정 코드와, 자주 만나는 함정도 같이 다룹니다.

왜 디코딩이 느린가: 프리필과 디코딩의 비용 구조

LLM 생성은 크게 두 단계입니다.

  • Prefill(프롬프트 인코딩): 입력 시퀀스 길이 L에 대해 attention이 주로 O(L^2)로 커집니다.
  • Decode(토큰 1개씩 생성): 매 스텝마다 새 토큰 1개를 추가하며, 이전 토큰들의 K/V를 재사용합니다. 여기서 KV 캐시가 없으면 매번 과거 토큰까지 다시 계산해 사실상 매 스텝이 비싸집니다.

실무에서 “한 번 응답 시작하면 첫 토큰은 늦고 이후는 빠르다” 혹은 “토큰 생성이 끝까지 꾸준히 느리다” 같은 증상이 나오는데, 전자는 prefill 병목, 후자는 decode 병목(캐시 미사용, 커널 비최적화, 메모리 대역폭 등)일 가능성이 큽니다.

핵심 1: flash-attn으로 attention 커널 최적화

flash-attn이 주는 이점

flash-attn은 attention 연산을 GPU 메모리 접근까지 고려해 재구성한 커널로, 전통적인 구현 대비 다음 이점이 있습니다.

  • 중간 텐서(materialization) 감소로 메모리 트래픽 절감
  • 더 나은 타일링/퓨전으로 GPU 활용률 향상
  • 특히 긴 시퀀스 prefill에서 효과가 크게 나타나는 경우가 많음

Transformers에서는 모델/버전/환경에 따라 flash_attention_2 또는 sdpa(PyTorch Scaled Dot-Product Attention) 경로로 연결됩니다. 목표는 결국 빠른 attention 백엔드로 안정적으로 타게 하는 것입니다.

설치/환경 체크 포인트

환경별로 설치 난이도가 달라서, 먼저 아래를 확인하는 게 안전합니다.

  • GPU 아키텍처와 CUDA 버전
  • PyTorch 버전(특히 torch의 SDPA 지원)
  • Transformers 버전

실무 팁:

  • 최신 PyTorch는 SDPA 경로가 좋아져서, flash-attn을 직접 설치하지 않아도 sdpa만으로도 충분히 빨라지는 경우가 있습니다.
  • 다만 특정 모델/정밀도 조합에서는 flash_attention_2가 더 잘 나오는 케이스가 있어, 둘 다 옵션으로 준비해두면 좋습니다.

Transformers에서 attention 구현 선택하기

아래는 Transformers에서 attention 구현을 명시하는 예시입니다. 모델에 따라 지원 여부가 다르니, 실패 시 sdpa로 폴백하도록 구성하는 게 실전적입니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# attention_implementation: "flash_attention_2" | "sdpa" | "eager"
attn_impl = "flash_attention_2"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
        attn_implementation=attn_impl,
    )
except Exception:
    # 환경에 따라 flash-attn이 안 타는 경우가 있어 폴백
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
        attn_implementation="sdpa",
    )

model.eval()

제대로 적용됐는지 확인하는 방법

  • 가장 확실한 건 프로파일링이지만, 간단히는 모델 로딩 로그/구성값을 확인합니다.
  • 또한 동일 프롬프트로 prefill 시간을 비교하면 차이가 빠르게 드러납니다.

핵심 2: KV 캐시를 “정확히” 켜고, 메모리 레이아웃을 이해하기

KV 캐시란 무엇이고 왜 2배가 가능한가

디코딩 단계에서 매 토큰마다 attention은 과거 토큰들의 KV를 참조합니다. KV 캐시가 켜져 있으면:

  • 과거 토큰들의 K/V를 다시 계산하지 않고 재사용
  • 매 스텝 계산량이 크게 줄어듦

즉, 긴 응답을 생성할수록 KV 캐시의 효과가 커집니다.

Transformers에서 KV 캐시 활성화

대부분의 CausalLM은 기본적으로 캐시를 쓰지만, 설정/경로에 따라 꺼질 수 있습니다. 아래처럼 명시해두는 편이 안전합니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
)
model.eval()

# 캐시 사용을 명시
model.generation_config.use_cache = True

prompt = "Explain KV cache in simple terms."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))

캐시가 꺼지는 대표 케이스

다음 조건에서는 캐시가 비활성화되거나 기대만큼 이득이 줄어들 수 있습니다.

  • gradient_checkpointing이 켜진 상태(학습/파인튜닝 설정이 추론에 섞인 경우)
  • 일부 모델/설정에서 use_cache=False가 기본값으로 저장된 경우
  • 빔서치에서 빔 수가 커질 때 캐시 메모리 사용량이 급증하여 병목이 바뀌는 경우

KV 캐시의 메모리 비용: 속도와 VRAM의 트레이드오프

KV 캐시는 빠르지만 VRAM을 먹습니다. 대략적으로는:

  • 레이어 수, 헤드 수, head_dim, 시퀀스 길이에 비례
  • 생성 길이가 길수록 캐시가 계속 커짐

로컬 GPU에서 OOM이 난다면, 속도만 보지 말고 캐시/정밀도/양자화를 함께 봐야 합니다. 관련해서는 아래 글에서 4bit와 KV 캐시를 함께 다루며 OOM 회피 전략을 정리해두었습니다.

“2배”를 만들기 위한 실전 조합: 권장 설정 레시피

아래는 로컬 단일 GPU에서 가장 흔히 잘 먹히는 조합입니다.

레시피 A: FP16 또는 BF16 + flash_attention_2 + KV 캐시

  • 목표: 최대 속도
  • 조건: 비교적 여유 있는 VRAM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
model.eval()

prompt = "Write a short technical note about GPU memory bandwidth."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.inference_mode():
    y = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
        # 가능하면 패딩 토큰/종료 토큰도 명시해 경고를 줄임
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(y[0], skip_special_tokens=True))

레시피 B: SDPA + KV 캐시 (flash-attn 설치가 어려울 때)

  • 목표: 설치 리스크 최소화
  • 조건: PyTorch 최신 권장
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
model.generation_config.use_cache = True

레시피 C: 긴 컨텍스트에서 prefill 최적화 우선

긴 프롬프트(예: RAG로 문서 다량 투입)에서는 prefill이 전체 지연의 대부분을 차지합니다. 이때는:

  • flash_attention_2 또는 sdpa로 prefill을 먼저 줄이고
  • 필요하면 프롬프트를 chunking하거나, 검색 결과를 압축(summarize)해 입력 토큰 자체를 줄이는 게 더 큰 효과를 내기도 합니다.

서빙 환경에서 cold start나 스케일 이슈까지 얽히면, 모델 추론 최적화 외에도 운영 레벨 튜닝이 필요합니다. KServe를 쓰는 경우라면 아래 글이 같이 도움이 됩니다.

속도 측정: 토큰/초를 재는 최소 코드

최적화는 “느낌”이 아니라 숫자로 확인해야 합니다. 아래 코드는 prefill과 decode를 완전히 분리해 재긴 어렵지만, 실무에서 가장 간단히 **토큰 생성 속도(tokens/sec)**를 비교하는 용도로 충분히 쓸 수 있습니다.

주의: GPU 비동기 실행 때문에 torch.cuda.synchronize()를 넣어야 측정이 정확해집니다.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.inference_mode()
def benchmark_generate(model, tokenizer, prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    torch.cuda.synchronize()
    t0 = time.time()

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    torch.cuda.synchronize()
    t1 = time.time()

    # 생성된 new tokens 수 추정: 전체 길이 - 입력 길이
    new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
    dt = t1 - t0
    return new_tokens / dt, dt

model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
model.eval()

prompt = "Summarize the concept of KV cache in 3 bullet points."
tps, sec = benchmark_generate(model, tokenizer, prompt, max_new_tokens=256)
print(f"tokens/sec={tps:.2f}, elapsed={sec:.2f}s")

이제 attn_implementationflash_attention_2로 바꾸거나, use_cache=False로 바꿔가며 비교하면 병목이 어디인지 빠르게 감이 옵니다.

자주 터지는 함정과 해결 체크리스트

1) 토큰/초가 오히려 떨어진다

  • 프롬프트가 너무 짧으면 prefill 최적화 효과가 잘 안 보일 수 있습니다.
  • 생성 길이가 너무 짧아도 KV 캐시 이득이 작습니다.
  • GPU가 이미 다른 프로세스로 바쁘면 변동이 큽니다.

권장: 동일 조건으로 3회 이상 반복 측정하고 중앙값을 보세요.

2) VRAM은 남는데 느리다

  • compute가 아니라 메모리 대역폭이 병목일 수 있습니다.
  • KV 캐시 접근이 많아지는 decode는 특히 메모리 병목이 두드러집니다.

대응:

  • 더 낮은 정밀도(FP16, BF16, 8bit, 4bit) 고려
  • 배치/동시성 조정(서빙이라면 request batching)

3) OOM이 난다

  • KV 캐시가 누적되며 길이에 비례해 VRAM을 계속 사용합니다.
  • 긴 컨텍스트 + 긴 생성은 캐시가 빠르게 커집니다.

대응:

  • max_new_tokens 제한
  • 컨텍스트 길이 제한
  • 양자화 적용(위 내부 링크 글 참고)

4) 서빙에서 성능이 들쭉날쭉하다

로컬 단일 실행과 달리, 서빙은 동시 요청, 배치, 콜드스타트, 스케일 정책이 성능을 크게 흔듭니다. 특히 스케일0에서 올라오는 지연은 “모델이 느린 것”처럼 보이기도 합니다.

정리: 가장 확률 높은 2배 가속 루트

  • prefill이 느리면: attn_implementationflash_attention_2 또는 sdpa로 전환해 attention 커널 최적화부터 확인
  • decode가 느리면: use_cache=True가 확실히 적용됐는지 확인하고, 생성 길이를 늘려 효과를 측정
  • VRAM이 빡빡하면: KV 캐시의 메모리 비용을 감안해 8bit/4bit 등과 함께 설계

결국 “2배”는 단일 마법 옵션이 아니라, attention 커널KV 캐시를 올바른 경로로 태우고, 측정으로 병목을 확인하며, VRAM 제약 안에서 균형점을 찾을 때 가장 재현성 있게 달성됩니다.