Published on

Transformers 로컬 LLM 2배 가속 - FlashAttention2

Authors

로컬에서 LLM을 돌릴 때 체감 성능을 가장 크게 갉아먹는 구간은 대개 attention 입니다. 특히 긴 컨텍스트(예: 4k~32k)에서 토큰을 생성할수록 KV 캐시와 어텐션 연산이 병목이 되기 쉽습니다. 이때 FlashAttention2는 어텐션을 더 효율적인 커널로 바꿔 메모리 접근을 줄이고 GPU를 더 잘 태워서, 동일 모델에서도 토큰 생성 속도를 유의미하게 끌어올릴 수 있습니다.

이 글에서는 Hugging Face Transformers 기반 로컬 추론 환경에서 FlashAttention2를 적용하는 실전 절차를 다룹니다. 설치 조건, 코드 적용 포인트, 속도 측정 방법, 그리고 가장 흔하게 부딪히는 오류를 함께 정리합니다.

FlashAttention2가 빨라지는 이유(핵심만)

기본 scaled dot-product attention은 대략 다음 흐름을 가집니다.

  • QK^T 계산
  • softmax 적용
  • softmax(QK^T) V 계산

이 과정에서 중간 텐서가 커지고(특히 시퀀스 길이가 길수록), GPU 메모리 대역폭과 캐시 미스가 병목이 됩니다. FlashAttention2는 이 중간 텐서를 크게 만들지 않도록 타일링/퓨전된 커널로 계산해, 메모리 트래픽을 줄이고 연산 효율을 높입니다.

결과적으로 다음 상황에서 이득이 커지는 경향이 있습니다.

  • 컨텍스트가 길다(프롬프트가 길거나, 대화 히스토리가 길거나)
  • 배치가 어느 정도 있다(동시 요청, 혹은 batched generation)
  • GPU가 충분히 강하다(Ampere 이상에서 특히 체감)

적용 전 체크리스트(환경 조건)

FlashAttention2는 “아무 GPU에서나” 되는 옵션이 아닙니다. 아래를 먼저 확인하세요.

1) GPU 아키텍처

일반적으로 NVIDIA Ampere(A100, RTX 30) 이상에서 가장 안정적으로 쓰입니다. Turing(RTX 20)도 일부 케이스에서 되지만, 조합에 따라 실패하거나 기대만큼 이득이 안 나올 수 있습니다.

2) CUDA, PyTorch, Transformers 조합

  • PyTorch는 CUDA 빌드가 되어 있어야 합니다.
  • Transformers는 attn_implementation 옵션을 지원하는 버전이 필요합니다.

아래로 빠르게 확인합니다.

import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))

3) dtype 권장

추론에서는 보통 float16 또는 bfloat16이 권장됩니다.

  • RTX 30/40 계열: float16이 무난
  • A100/H100: bfloat16도 좋음

설치: flash-attn(FlashAttention2) 준비

FlashAttention2는 보통 flash-attn 패키지로 설치합니다. 이 패키지는 컴파일이 필요할 수 있어, 환경에 따라 설치 난이도가 갈립니다.

방법 A: pip 설치(가능하면 가장 간단)

pip install -U flash-attn --no-build-isolation
  • --no-build-isolation은 빌드 시 PyTorch/CUDA 헤더를 찾는 문제를 줄이는 데 도움이 되는 경우가 많습니다.

방법 B: 설치가 자주 깨질 때 점검 포인트

설치가 실패한다면 보통 아래 중 하나입니다.

  • CUDA toolkit 버전 불일치
  • PyTorch CUDA 버전과 로컬 CUDA toolkit이 맞지 않음
  • ninja, gcc/g++ 등 빌드 도구 부족

빌드 도구는 대략 아래가 필요합니다.

pip install -U ninja
# Ubuntu 예시
sudo apt-get update
sudo apt-get install -y build-essential

컨테이너로 운영할 계획이라면, 추론 서빙 관점에서는 배포/콜드스타트까지 함께 고려하는 게 좋습니다. GPU 서빙을 다룬 글로는 KServe로 GPU LLM 배포 - 콜드스타트 0에 가깝게도 함께 참고하면 연결이 됩니다.

Transformers에서 FlashAttention2 켜는 법

Transformers는 모델 로딩 시 어텐션 구현을 선택할 수 있습니다. 핵심은 attn_implementationflash_attention_2를 지정하는 것입니다.

아래 예시는 로컬에서 AutoModelForCausalLM로 로딩해 생성 속도를 올리는 가장 일반적인 패턴입니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

prompt = "Explain the key idea of FlashAttention2 in one paragraph."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,
        use_cache=True,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))

포인트 1) use_cache=True는 사실상 필수

생성 단계에서는 KV 캐시를 켜야 합니다. FlashAttention2가 빨라도 캐시를 끄면 생성이 급격히 느려집니다.

포인트 2) device_map과 dtype

  • 단일 GPU면 device_map="cuda"가 간단합니다.
  • 멀티 GPU 샤딩이면 device_map="auto"를 쓰되, FlashAttention2가 모델 샤딩 환경에서 잘 동작하는지(특히 커스텀 구조) 확인이 필요합니다.

포인트 3) 모델별 지원 여부

대부분의 대표 Causal LM 계열은 잘 되지만, 모델 구현이 특이하거나 attention이 커스텀으로 바뀐 경우 적용이 안 되거나 fallback이 일어날 수 있습니다.

제대로 적용됐는지 확인하는 방법

가장 확실한 방법은 “속도 측정”이지만, 그 전에 설정이 실제로 먹었는지 간단히 확인할 수 있습니다.

1) 로딩 로그/경고 확인

환경에 따라 Transformers가 FlashAttention2 적용 실패 시 경고를 출력하거나, 내부적으로 다른 구현으로 fallback할 수 있습니다. 실행 시 표준 출력 경고를 꼭 확인하세요.

2) 간단 벤치마크 코드(토큰/초)

아래 코드는 동일 프롬프트로 워밍업 후 토큰 생성 속도를 대략 비교할 수 있습니다.

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

@torch.inference_mode()
def bench(model, tokenizer, prompt, max_new_tokens=256, n=5):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # warmup
    _ = model.generate(**inputs, max_new_tokens=32, do_sample=False, use_cache=True)
    torch.cuda.synchronize()

    times = []
    for _ in range(n):
        t0 = time.time()
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
        torch.cuda.synchronize()
        t1 = time.time()

        gen_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
        times.append(gen_tokens / (t1 - t0))

    return sum(times) / len(times)

model_id = "meta-llama/Llama-2-7b-hf"
prompt = "Write a detailed checklist for deploying a local LLM." * 50  # 길게 해서 차이를 보기 좋게

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# baseline
m0 = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="sdpa",  # 또는 기본값
)

# flash-attn2
m1 = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

s0 = bench(m0, tok, prompt)
s1 = bench(m1, tok, prompt)

print(f"sdpa tok/s: {s0:.2f}")
print(f"fa2  tok/s: {s1:.2f}")
print(f"speedup: {s1/s0:.2f}x")
  • 프롬프트를 길게 잡아야 차이가 잘 납니다.
  • 측정 전 torch.cuda.synchronize()로 타이밍 정확도를 확보합니다.

체감 2배를 만들려면: 병목도 같이 정리하기

FlashAttention2만 켰는데도 기대만큼 안 빨라진다면, 병목이 어텐션이 아닐 수 있습니다. 로컬 추론에서 흔한 병목을 같이 점검하세요.

1) CPU 토크나이징이 느린 경우

  • use_fast=True로 Fast tokenizer 사용
  • 배치가 크면 토크나이징을 병렬화하거나, 입력 파이프라인을 최적화

2) 샘플링 옵션이 과도한 경우

top_p, top_k, temperature 자체가 큰 비용은 아니지만, 로직이 복잡해지고 반복 호출이 늘면 미세하게 영향을 줍니다. 성능 테스트는 do_sample=False로 고정해 비교하는 편이 좋습니다.

3) 메모리 부족으로 인한 스로틀링

VRAM이 부족하면 다음이 발생합니다.

  • 활성화/캐시가 밀려서 성능 급락
  • 일부가 CPU 오프로딩되어 지연 증가

이 경우는 FlashAttention2보다 먼저 양자화나 KV 캐시 전략을 검토해야 합니다. 양자화 튜닝을 하다가 정확도 문제가 생기는 케이스도 많아서, 관련해서는 PyTorch PTQ·QAT 정확도 급락 원인·복구 7단계처럼 “성능 최적화가 정확도를 깨는 지점”도 함께 보면서 접근하는 게 안전합니다.

자주 만나는 오류와 해결

오류 1) flash_attn 설치 실패(컴파일 에러)

증상

  • nvcc를 못 찾는다
  • CUDA_HOME 관련 에러
  • gcc 버전 문제

대응

  • PyTorch CUDA 버전과 호환되는 CUDA toolkit을 사용
  • 컨테이너라면 CUDA 개발 이미지(런타임 말고 dev) 기반으로 빌드
  • ninja, build-essential 설치

오류 2) FlashAttention2가 적용되지 않고 조용히 느림

가능 원인

  • GPU 아키텍처 미지원
  • dtype이 float32
  • 모델 구조가 해당 커널 경로를 타지 않음

대응

  • torch_dtypefloat16 또는 bfloat16으로
  • 프롬프트 길이를 늘려 다시 벤치마크
  • Transformers 버전 업데이트 후 재시도

오류 3) 긴 컨텍스트에서 OOM

FlashAttention2가 메모리를 줄여주긴 하지만, KV 캐시는 생성 토큰 수와 레이어 수에 비례해 계속 커집니다.

대응

  • max_new_tokens를 줄이거나
  • 모델 크기를 낮추거나
  • 양자화(특히 weight-only) 또는 더 큰 VRAM GPU 사용

운영 관점 팁: 로컬에서 빨라진 다음 단계

로컬에서 속도가 만족스럽게 나오면, 다음은 “어떻게 안정적으로 서비스하느냐”가 문제가 됩니다.

  • 모델 워밍업 및 GPU 점유 전략
  • 오토스케일링 시 cold start
  • 요청 큐잉과 배치 처리

이 단계는 단순히 커널 최적화가 아니라 서빙 아키텍처 문제로 넘어가며, GPU 서빙 시스템을 고려한다면 앞서 언급한 KServe로 GPU LLM 배포 - 콜드스타트 0에 가깝게가 도움이 됩니다.

정리

  • FlashAttention2는 Transformers 로컬 LLM 추론에서 어텐션 병목을 줄여 토큰 생성 속도를 크게 올릴 수 있습니다.
  • 적용은 flash-attn 설치 후, 모델 로딩 시 attn_implementationflash_attention_2를 지정하는 것이 핵심입니다.
  • “정말 2배”를 만들려면 프롬프트 길이, dtype, KV 캐시, VRAM 여유, 토크나이징 병목까지 같이 점검해야 합니다.

같은 모델/같은 GPU에서도 워크로드(프롬프트 길이, 배치, 생성 길이)에 따라 가속 비율이 크게 달라지니, 위 벤치마크 코드로 본인 환경에서 tok/s를 먼저 수치로 잡고 튜닝을 진행하는 것을 권장합니다.