Published on

Transformers 로컬 LLM 2배 가속 - KV 캐시·FA2

Authors

로컬에서 LLM을 돌리다 보면 “모델은 잘 뜨는데 왜 이렇게 느리지?”가 가장 흔한 벽입니다. 특히 대화형(챗) 워크로드는 토큰을 한 번에 길게 생성하기보다, 매 요청마다 짧게 생성하고(스트리밍), 직전 대화 히스토리를 다시 먹이는 패턴이 많습니다. 이때 성능을 결정하는 핵심은 크게 두 가지입니다.

  • KV 캐시(Key/Value cache): 이전 토큰들의 어텐션 K/V를 재사용해, 매 스텝마다 과거 토큰을 다시 계산하는 비용을 제거
  • FlashAttention2(FA2): 어텐션 연산을 메모리 효율적으로 재구성해, GPU에서 병목이 되는 메모리 트래픽을 줄이고 처리량을 올림

이 글은 Transformers 기반 로컬 추론에서 체감 2배를 노릴 때 가장 재현성이 높은 조합(= use_cache + FA2 + 올바른 설정/측정)을 중심으로 정리합니다. 추가로 정밀 최적화가 필요하면 INT8/FP8 계열로 넘어가야 하는데, 그 단계는 별도로 다루는 편이 좋습니다. 관련해서는 PyTorch에서 TensorRT INT8로 3배 가속하기도 함께 참고하면 좋습니다.

1) 먼저 병목을 분리하자: Prefill vs Decode

LLM 추론은 크게 두 구간으로 나뉩니다.

  • Prefill(프리필): 프롬프트(입력 토큰 전체)를 한 번에 통과시키는 구간
  • Decode(디코드): 다음 토큰을 1개씩(또는 작은 배치로) 반복 생성하는 구간

대화형에서 “느리다”는 불만의 대부분은 **Decode TPS(tokens/sec)**가 낮아서 발생합니다. KV 캐시와 FA2는 특히 Decode 구간에서 효과가 큽니다.

간단한 측정 루틴을 먼저 만들어 두면, 최적화가 실제로 먹히는지 바로 확인할 수 있습니다.

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model.eval()

def measure(prompt: str, max_new_tokens: int = 128):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    torch.cuda.synchronize()
    t0 = time.time()
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
    torch.cuda.synchronize()
    t1 = time.time()

    total_tokens = out.shape[-1]
    prompt_tokens = inputs["input_ids"].shape[-1]
    gen_tokens = total_tokens - prompt_tokens

    return {
        "prompt_tokens": prompt_tokens,
        "gen_tokens": gen_tokens,
        "sec": t1 - t0,
        "tok_per_sec": gen_tokens / (t1 - t0 + 1e-9),
    }

print(measure("Explain KV cache in 3 bullets.", 200))

이제부터는 이 측정값(특히 tok_per_sec)이 얼마나 올라가는지로 판단하면 됩니다.

2) KV 캐시: 원리와 “진짜로 켜졌는지” 확인법

2-1. KV 캐시가 하는 일

Transformer 디코딩은 매 스텝마다 “지금까지의 모든 토큰”에 대해 어텐션을 계산합니다. 캐시가 없으면 새 토큰 1개를 만들 때도 과거 토큰들의 K/V를 다시 만들고 다시 읽습니다.

KV 캐시는 과거 토큰들의 K/V를 레이어별로 저장해두고, 다음 스텝에서는 새 토큰에 해당하는 K/V만 추가합니다. 그래서 디코드 구간의 계산량이 크게 줄어듭니다.

2-2. Transformers에서 KV 캐시 활성화

대부분의 CausalLM은 기본적으로 use_cache=True가 켜져 있지만, 다음 경우에 꺼지거나 무력화될 수 있습니다.

  • gradient_checkpointing 같은 학습용 설정이 켜진 상태
  • 일부 모델/설정에서 config.use_cache=False
  • torch.compile 또는 특수한 래퍼에서 캐시 경로가 깨짐

확실히 하려면 model.config.use_cache를 강제로 맞추고, generate(..., use_cache=True)를 명시합니다.

model.config.use_cache = True

out = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=False,
    use_cache=True,
)

2-3. 캐시가 실제로 쓰이는지 확인(디버깅)

가장 확실한 방법은 forward를 직접 호출해 past_key_values가 나오고, 다음 스텝에서 다시 입력했을 때 연산이 줄어드는지 확인하는 것입니다.

with torch.inference_mode():
    x = tokenizer("Hello", return_tensors="pt").to("cuda")
    y1 = model(**x, use_cache=True)
    pkv = y1.past_key_values

    # 다음 토큰(가짜로 1토큰)만 넣고 캐시를 전달
    next_id = torch.argmax(y1.logits[:, -1, :], dim=-1, keepdim=True)
    y2 = model(input_ids=next_id, past_key_values=pkv, use_cache=True)

print(type(pkv), len(pkv))

여기서 past_key_valuesNone이면 캐시 경로가 꺼진 것입니다.

2-4. 주의: KV 캐시는 “메모리로 속도를 사는” 최적화

KV 캐시는 VRAM을 꽤 먹습니다. 대략적으로 레이어 수, 헤드 수, 히든 차원, 시퀀스 길이에 비례해 증가합니다. 증상이 보통 이렇게 나타납니다.

  • 짧은 프롬프트에서는 빠른데, 대화가 길어질수록 점점 느려짐
  • 어느 순간 CUDA OOM

대화형 서비스에서 흔한 대응은 아래 중 하나입니다.

  • 히스토리 요약 또는 윈도잉(최근 N 토큰만 유지)
  • 배치 크기 축소
  • 더 작은 모델 또는 양자화로 VRAM 확보

3) FlashAttention2: “켜면 빨라지는” 조건과 설치 포인트

3-1. FlashAttention2가 이득이 큰 조건

FA2는 어텐션을 메모리 효율적으로 계산해 GPU에서 더 높은 처리량을 내는 구현입니다. 다만 항상 이득이 보장되진 않고, 보통 아래 조건에서 효과가 큽니다.

  • NVIDIA GPU(주로 Ampere 이상) + 적절한 CUDA 환경
  • bf16 또는 fp16 추론
  • 시퀀스 길이가 어느 정도 길거나, 디코드 스텝이 충분히 많은 워크로드

CPU 추론에는 해당이 없고, AMD/Apple 계열은 별도 경로가 필요합니다.

3-2. Transformers에서 FA2 활성화

Transformers는 모델 로딩 시 attn_implementation 옵션으로 어텐션 커널을 선택할 수 있습니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.eval()

환경에 따라 flash-attn 패키지 설치가 필요합니다. 설치가 막히는 경우가 잦으니, 다음을 먼저 확인하세요.

  • PyTorch 버전과 CUDA 버전 호환
  • GPU Compute Capability
  • 컴파일 툴체인(특히 리눅스에서 ninja, gcc) 유무

설치 예시는 환경마다 달라서 여기서 단정하기 어렵지만, 원칙은 “현재 PyTorch/CUDA 조합에 맞는 flash-attn 빌드”입니다.

3-3. FA2가 실제로 적용됐는지 확인

가장 흔한 실패는 옵션을 줬는데 내부적으로 fallback 되는 케이스입니다(예: eager attention). 모델 로딩 직후 설정을 확인하거나, 로그/프로파일러로 커널이 바뀌었는지 확인합니다.

print(getattr(model.config, "attn_implementation", None))

더 확실히 하려면 torch.profiler로 attention 관련 커널이 flash 계열로 찍히는지 확인합니다.

4) KV 캐시 + FA2 조합에서 체감 2배를 만드는 설정들

여기부터는 “자잘하지만 차이를 만드는” 실전 항목입니다.

4-1. torch.inference_mode()model.eval()

추론에서 autograd를 완전히 끄면 메모리/오버헤드가 줄어듭니다.

model.eval()
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=256, use_cache=True)

4-2. dtype 선택: 가능하면 bf16

Ampere 이상이면 bf16이 안정적이고 속도도 잘 나옵니다. fp16은 일부 모델에서 수치 불안정이 생길 수 있습니다.

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

4-3. generate 파라미터가 성능에 미치는 영향

  • do_sample=True는 샘플링 연산이 추가되지만 보통 병목은 어텐션이라 영향은 제한적입니다.
  • num_beams(빔서치)는 거의 항상 느려집니다(빔 수만큼 디코드가 늘어남).
  • max_new_tokens가 커질수록 KV 캐시/FA2 효과가 더 잘 드러납니다.

권장 출발점:

gen = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    num_beams=1,
    use_cache=True,
)

4-4. 배치 전략: “동시 요청”이 있으면 배치가 이득

로컬 챗봇이라도 요청이 동시 다발이면, 단일 스트림보다 작은 배치가 GPU를 더 잘 채워 TPS가 올라갈 수 있습니다. 단, 배치가 커지면 KV 캐시 메모리도 같이 늘어납니다.

서비스 형태로 묶을 때는 “동시성 제어 + 마이크로배칭”이 핵심인데, 캐시가 꼬이지 않도록 요청별 시퀀스를 분리해야 합니다.

5) 흔한 실패 사례와 체크리스트

5-1. “속도가 그대로”인 경우

  • 실제로는 CPU에서 돌고 있음: device_map="cuda" 확인
  • attn_implementation가 fallback 됨: 설정 확인 및 프로파일링
  • 프롬프트가 너무 짧아 측정 노이즈가 큼: max_new_tokens를 256 이상으로
  • 디코드가 아니라 프리필이 병목: 입력 길이가 지나치게 길면 프리필 최적화(프롬프트 압축/요약)가 먼저

5-2. OOM이 나는 경우

  • KV 캐시가 누적되는 대화형에서 자주 발생
  • 해결: 히스토리 윈도우, 배치 축소, 모델 축소, 양자화, 또는 더 큰 VRAM

5-3. 재현 가능한 성능 측정을 위한 팁

  • 측정 전 torch.cuda.synchronize()
  • 워밍업 1회 후 측정(첫 실행은 커널 로딩/캐시로 느림)
  • 동일 프롬프트/동일 max_new_tokens로 비교
def bench(prompt, n=3):
    measure(prompt, 64)  # warmup
    stats = [measure(prompt, 256) for _ in range(n)]
    tps = sum(s["tok_per_sec"] for s in stats) / n
    return tps, stats

print(bench("Write a short API design note about pagination."))

6) 추가 가속 로드맵: 2배 이후에 뭘 할까

KV 캐시와 FA2로도 부족하면 다음 단계로 넘어갑니다.

  • 양자화(INT8/INT4): VRAM을 줄여 더 큰 배치/더 긴 컨텍스트를 가능하게 하거나, 동일 VRAM에서 속도를 올림
  • TensorRT-LLM / vLLM / llama.cpp 계열: 목적이 “최고 TPS”라면 Transformers 단독보다 추론 엔진 교체가 더 큰 점프를 주는 경우가 많음

다만 이번 글의 목표는 “Transformers를 유지하면서” 비교적 안전하게 얻을 수 있는 2배급 개선이므로, 그 범위에서는 use_cache와 FA2가 가장 먼저 시도할 카드입니다. 양자화 쪽은 PyTorch에서 TensorRT INT8로 3배 가속하기에서 더 깊게 다룬 내용을 참고해도 좋습니다.

7) 결론: 가장 효율적인 적용 순서

로컬 LLM 추론 속도를 올릴 때, 시행착오가 적은 순서는 다음과 같습니다.

  1. 측정 루틴부터 만든다: tok_per_sec로 비교 가능하게
  2. KV 캐시가 확실히 켜졌는지 검증한다: past_key_values 확인
  3. FlashAttention2를 적용한다: attn_implementation 설정 + fallback 여부 확인
  4. dtype와 generate 설정을 정리한다: bf16, 빔서치 금지, 워밍업/동기화
  5. OOM과 장문 대화에 대비한다: 히스토리 윈도우/요약

이 조합만 제대로 맞아도 “로컬에서 답답함이 사라지는” 수준의 개선을 얻는 경우가 많습니다. 다음 글에서는 같은 환경에서 프롬프트 길이(컨텍스트) 증가가 성능에 미치는 영향과, KV 캐시 메모리 추정/윈도잉 전략을 더 구체적으로 다뤄보겠습니다.