Published on

Transformers 로컬 LLM 느림? FlashAttention2 적용법

Authors
Binance registration banner

로컬 GPU에서 transformers로 LLM을 돌리는데 토큰 생성 속도가 기대보다 느리면, 가장 먼저 의심할 만한 지점이 어텐션 구현입니다. 기본 scaled dot-product attention은 메모리 트래픽이 크고, 특히 긴 컨텍스트에서 병목이 심해지기 쉽습니다.

FlashAttention2는 어텐션을 타일링하고 GPU SRAM을 적극 활용해 메모리 접근을 줄이는 방식으로, 같은 모델이라도 토큰 생성 속도와 VRAM 사용량을 눈에 띄게 개선하는 경우가 많습니다. 이 글에서는 transformers에서 FlashAttention2를 적용하는 실전 절차와, 적용이 안 되거나 오히려 느려질 때 체크할 포인트를 정리합니다.

참고로, 로컬 추론 최적화는 FlashAttention 외에도 TensorRT, 양자화, 컴파일 등 선택지가 많습니다. ONNX 및 TensorRT로 넘어가다 생기는 삽질 포인트는 별도 글인 PyTorch→ONNX→TensorRT INT8 양자화 오류 해결도 함께 참고하면 좋습니다.

왜 로컬 LLM이 느릴까: 병목을 먼저 분리하기

체감상 “느리다”는 현상은 크게 두 가지로 나뉩니다.

  1. 프리필(prefill) 구간이 느림: 입력 프롬프트를 한 번에 인코딩하는 단계가 오래 걸립니다. 긴 컨텍스트일수록 두드러집니다.
  2. 디코딩(decode) 구간이 느림: 토큰을 한 개씩 생성하는 단계가 느립니다. KV 캐시, 배치 크기, 샘플링 설정, GPU utilization에 영향을 받습니다.

FlashAttention2는 주로 어텐션 연산 자체를 최적화하므로, 특히 긴 컨텍스트의 프리필에서 이득이 크고, 디코딩에서도 모델과 환경에 따라 개선이 나올 수 있습니다. 다만 아래 조건이면 효과가 제한되거나 적용 자체가 안 될 수 있습니다.

  • GPU가 너무 구형이거나(아키텍처 제약)
  • dtype이 맞지 않거나(fp16 또는 bf16 권장)
  • transformers 버전과 flash-attn 빌드가 안 맞거나
  • 모델이 FlashAttention2 경로를 지원하지 않거나

적용 전 체크리스트

1) GPU와 CUDA 환경

  • NVIDIA GPU 필요(대부분의 경우)
  • CUDA Toolkit 및 드라이버 호환
  • PyTorch가 해당 CUDA 버전으로 설치되어 있어야 함

아래로 현재 환경을 빠르게 확인합니다.

import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))

get_device_capability는 예를 들어 8.0, 8.6, 9.0 같은 형태로 나오며, 일반적으로 Ampere 이상에서 안정적인 성능 이득을 기대하는 편입니다.

2) dtype과 메모리

FlashAttention2는 보통 float16 또는 bfloat16에서 잘 동작합니다. float32로 돌리면 속도도 느리고 VRAM도 크게 먹습니다.

  • 소비자 GPU에서 bf16이 애매하면 fp16부터 시도
  • 서버 GPU(A100, H100 등)면 bf16도 좋은 선택

3) 모델이 지원하는지

최근 transformers는 많은 디코더 계열 모델에서 attn_implementation 옵션으로 FlashAttention2를 선택할 수 있습니다. 다만 모델 아키텍처나 버전에 따라 지원 여부가 다릅니다.

FlashAttention2 설치: 가장 흔한 실패 지점

FlashAttention2는 보통 flash-attn 패키지로 설치합니다. 문제는 이 패키지가 CUDA 확장 빌드를 포함하기 때문에, PyTorch CUDA 버전, 컴파일러, CUDA Toolkit, 그리고 OS 환경이 어긋나면 설치가 실패하거나 런타임에서 터집니다.

권장: 가상환경에서 버전 고정

python -m venv .venv
source .venv/bin/activate
pip install -U pip

PyTorch 설치 확인

이미 설치되어 있다면 넘어가도 되지만, CUDA 포함 빌드인지 확인하세요.

python -c "import torch; print(torch.__version__, torch.version.cuda)"

flash-attn 설치

환경에 따라 설치 커맨드는 달라질 수 있지만, 기본적으로는 아래를 시도합니다.

pip install -U flash-attn --no-build-isolation
  • --no-build-isolation은 종종 빌드 의존성 충돌을 줄여줍니다.
  • 설치가 실패하면 에러 로그에 nvcc, gcc, ninja, CUDA_HOME 등이 언급되는지 확인하세요.

설치 확인

python -c "import flash_attn; print('flash_attn ok')"

Transformers에서 FlashAttention2 켜기

핵심은 모델 로딩 시 attn_implementation을 지정하는 것입니다. 또한 dtype을 float16 또는 bfloat16로 맞추고, 가능하면 device_map을 사용해 GPU에 올립니다.

아래 예시는 일반적인 causal LM 추론 코드입니다.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

prompt = "Explain FlashAttention2 in simple terms."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))

적용 여부 확인 팁

  • 속도가 눈에 띄게 빨라졌는지(특히 긴 프롬프트에서)
  • VRAM 사용량이 줄었는지
  • 경고 로그에 “flash attention” 관련 메시지가 있는지

또는 모델 설정에 구현이 기록되는 경우도 있습니다.

print(getattr(model.config, "attn_implementation", None))

벤치마크: 토큰/초로 확인하기

체감만으로 판단하면 샘플링 옵션, 온도, 반복 실행 시 캐시 워밍 등으로 착시가 생깁니다. 아래처럼 토큰 생성 속도를 간단히 재보는 편이 좋습니다.

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

def load(attn_impl):
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
        attn_implementation=attn_impl,
    )

def bench(model, prompt, new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    torch.cuda.synchronize()
    t0 = time.time()
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=new_tokens,
            do_sample=False,
        )
    torch.cuda.synchronize()
    dt = time.time() - t0

    total_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
    return total_tokens / dt

prompt = "Write a detailed technical explanation of KV cache and attention." * 20

model_eager = load("eager")
print("eager tok/s:", bench(model_eager, prompt))

del model_eager
torch.cuda.empty_cache()

model_fa2 = load("flash_attention_2")
print("fa2 tok/s:", bench(model_fa2, prompt))
  • 프롬프트를 길게 만들어야 프리필 병목 차이가 잘 드러납니다.
  • torch.cuda.synchronize()를 넣어야 측정이 정확해집니다.

자주 겪는 문제와 해결 체크포인트

1) 적용했는데도 속도가 그대로인 경우

프롬프트가 너무 짧다

짧은 입력에서는 어텐션 최적화 이득이 작고, 오히려 오버헤드가 더 커 보일 수 있습니다. 긴 컨텍스트에서 다시 측정하세요.

디코딩 병목은 어텐션만의 문제가 아니다

토큰 생성은 다음 요소도 크게 좌우합니다.

  • 샘플링 설정(예: top_p, temperature)
  • 배치 크기
  • max_new_tokens
  • use_cache 활성화 여부

generate는 기본적으로 use_cache를 쓰지만, 모델 설정에 따라 달라질 수 있으니 확인합니다.

print("use_cache:", getattr(model.config, "use_cache", None))

2) 설치는 됐는데 런타임 에러가 나는 경우

대표적으로 CUDA 커널 로딩 실패, 심볼 불일치, ABI 문제 등이 있습니다. 이런 경우는 대개 아래 중 하나입니다.

  • PyTorch CUDA 버전과 flash-attn 빌드 대상 CUDA가 다름
  • 드라이버가 너무 낮음
  • 여러 CUDA가 섞여 LD_LIBRARY_PATH가 꼬임

해결 전략은 단순합니다.

  • PyTorch를 먼저 원하는 CUDA 버전으로 재설치
  • 그 다음 flash-attn을 재설치
  • 가능하면 깨끗한 가상환경에서 재현

빌드/캐시가 꼬였을 때는 Docker를 쓰는 것도 방법인데, 캐시가 기대대로 동작하지 않으면 Docker BuildKit 캐시가 안 먹는 9가지 원인 같은 체크리스트가 도움이 됩니다.

3) attn_implementation 옵션이 먹지 않는 경우

모델/버전에 따라 옵션명이 다르거나, 해당 아키텍처가 FlashAttention2 경로를 제공하지 않을 수 있습니다.

  • transformers를 최신으로 업데이트
pip install -U transformers accelerate
  • 모델이 flash_attention_2를 지원하는지 릴리즈 노트나 이슈를 확인
  • 그래도 안 되면 eager 또는 sdpa로 바꿔 성능을 비교
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="sdpa",
)

여기서 sdpa는 PyTorch의 scaled dot-product attention 경로를 의미하며, 환경에 따라 꽤 좋은 성능이 나올 수도 있습니다.

성능을 더 끌어올리는 실전 팁

1) torch.compile은 신중하게

PyTorch 2 계열에서 torch.compile로 이득을 보는 경우도 있지만, 모델/드라이버/커널 조합에 따라 컴파일 시간이 과도하거나 오히려 느려질 수 있습니다. FlashAttention2부터 안정화한 뒤, 별도 실험으로 접근하는 편이 안전합니다.

2) 긴 컨텍스트에서 VRAM이 빡빡하면

FlashAttention2는 메모리 효율이 좋아 긴 입력에서 특히 유리합니다. 그래도 VRAM이 부족하면 아래를 함께 고려합니다.

  • max_new_tokens를 줄여 디코딩 길이 제한
  • 더 작은 모델
  • 4bit 또는 8bit 양자화(예: bitsandbytes)

3) 로컬 서비스로 올릴 때는 병목이 다른 곳에 생긴다

로컬에서 단일 프로세스로 돌릴 때는 GPU가 병목이지만, API 서버로 감싸면 다음이 병목이 되기도 합니다.

  • 토크나이저 CPU 사용률
  • 동시 요청 처리 방식
  • 스트리밍 응답 구현

웹 프론트가 느려지는 병목을 진단하는 방식은 LLM과는 다르지만, “병목을 측정으로 쪼개서 해결한다”는 접근은 같습니다. 프론트 렌더 병목을 체계적으로 보는 글로는 Next.js INP 폭증? React 렌더 병목 7단계 진단이 참고가 됩니다.

결론: FlashAttention2는 가장 먼저 시도할 최적화

로컬에서 transformers로 LLM을 돌릴 때 속도가 안 나오면, FlashAttention2는 “설치와 적용만 성공하면 체감 이득이 큰” 최적화 중 하나입니다.

정리하면 순서는 다음이 가장 안전합니다.

  1. PyTorch CUDA 버전과 드라이버 정합성 확인
  2. flash-attn 설치 및 import 확인
  3. 모델 로딩에 attn_implementationflash_attention_2로 지정
  4. 긴 프롬프트로 토큰/초 벤치마크
  5. 여전히 느리면 sdpa 비교, dtype, KV cache, 샘플링 설정 점검

이 과정을 거치면 “그냥 느리다”에서 “어느 구간이 느린지, 어떤 옵션이 효과가 있는지”로 문제를 구체화할 수 있고, 이후 양자화나 TensorRT 같은 더 큰 최적화로 확장하기도 쉬워집니다.