Published on

Transformers 로컬 LLM 속도 2배 - FlashAttention2 적용

Authors

로컬에서 LLM을 돌릴 때 체감 성능을 좌우하는 건 결국 토큰 생성 속도(tokens/sec) 입니다. 같은 GPU라도 설정에 따라 2배 가까이 차이가 나는 경우가 흔한데, 그중 가장 강력한 레버 중 하나가 FlashAttention2 입니다.

이 글에서는 Hugging Face Transformers 기반 로컬 추론 환경에서 FlashAttention2를 적용해 속도를 끌어올리는 과정을, 설치부터 코드 적용, 벤치마크, 트러블슈팅까지 한 번에 정리합니다.

참고: FlashAttention2는 “모든 상황에서 무조건 2배”가 아니라, 시퀀스 길이, 배치 크기, 모델 구조, dtype, GPU 아키텍처에 따라 이득 폭이 달라집니다. 다만 로컬 LLM 추론에서 가장 흔한 병목인 attention 연산을 크게 최적화해 주기 때문에, 잘 맞으면 매우 큰 체감 향상을 줍니다.

FlashAttention2가 왜 빠른가

일반적인 attention은 QK^T 계산 후 softmax, 그리고 softmax(QK^T)V를 수행합니다. 이 과정에서 다음 문제가 생깁니다.

  • 메모리 대역폭 병목: 중간 행렬(특히 attention score)이 커서 HBM 읽기/쓰기가 폭증
  • 커널 런치 오버헤드: 연산이 여러 커널로 쪼개져 GPU 활용률이 떨어짐
  • 수치 안정성/정밀도 처리 비용: mixed precision에서 softmax 안정화 등 추가 비용

FlashAttention2는 attention을 타일링(tile) 하여 중간 행렬을 메모리에 크게 쓰지 않고, 가능한 한 on-chip(SRAM)에서 스트리밍 방식으로 처리합니다. 결과적으로 메모리 트래픽이 줄고, 커널이 더 효율적으로 합쳐져 속도와 메모리 사용량이 동시에 개선되는 경우가 많습니다.

적용 전 체크리스트

FlashAttention2를 적용하기 전에 환경을 먼저 확인하세요.

1) GPU/드라이버/파이토치 조합

  • NVIDIA GPU 권장(주로 Ampere 이후에서 이득이 큼)
  • CUDA/드라이버는 PyTorch 빌드와 호환되어야 함

다음으로 현재 환경을 확인합니다.

python -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.cuda.get_device_name(0))"

2) Transformers 버전

FlashAttention2 연동은 Transformers에서 점점 안정화되었습니다. 너무 오래된 버전이면 옵션이 없거나 동작이 불안정할 수 있어, 가급적 최신 계열을 권장합니다.

python -c "import transformers; print(transformers.__version__)"

3) dtype과 사용 시나리오

추론에서는 보통 torch.float16 또는 torch.bfloat16을 씁니다. FlashAttention2는 이 경로에서 효과가 좋습니다.

  • 단일 요청이라도 max_new_tokens가 크거나 context length가 길면 이득이 커짐
  • 배치 추론(여러 프롬프트 동시 처리)이면 이득 폭이 더 커질 수 있음

설치: flash-attn 준비

FlashAttention2는 일반적으로 flash-attn 패키지로 제공합니다. 다만 CUDA/컴파일 환경에 따라 설치가 까다로울 수 있습니다.

1) 기본 설치(가능하면 이 경로부터)

pip install -U flash-attn --no-build-isolation
  • --no-build-isolation은 빌드 과정에서 PyTorch/CUDA 탐지를 더 안정적으로 만드는 데 도움이 되는 경우가 많습니다.

2) 설치가 실패할 때 점검 포인트

설치 실패는 보통 다음 원인으로 발생합니다.

  • PyTorch의 CUDA 버전과 시스템 CUDA 툴킷 불일치
  • nvcc 또는 컴파일러 툴체인 부재
  • Python, pip, wheel 버전 문제

에러 로그에 CUDA_HOME 관련 메시지가 나오면 다음을 확인하세요.

echo $CUDA_HOME
nvcc --version

로컬 GPU 서버/워크스테이션에서 이런 류의 “환경 의존 빌드 실패”는 흔합니다. 쿠버네티스에서 추론 파드를 운영한다면, 이미지 레벨에서 CUDA/PyTorch 조합을 고정하고 캐시를 잘 설계하는 편이 안정적입니다. (CI 캐시 설계는 GitHub Actions 캐시 미적중? 키 설계 7원칙도 함께 참고할 만합니다.)

Transformers에서 FlashAttention2 적용(핵심)

Transformers는 모델 로딩 시 attention 구현체를 선택할 수 있습니다. 가장 흔한 적용 방식은 attn_implementation 옵션을 사용하는 것입니다.

아래 예시는 causal LM(예: Llama 계열)을 로컬에서 로드하고, FlashAttention2를 켠 뒤 텍스트 생성 속도를 측정합니다.

예제 1: AutoModelForCausalLM + FlashAttention2

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

prompt = "로컬 LLM 추론 속도를 높이는 방법을 3가지로 정리해줘."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# 워밍업
with torch.inference_mode():
    _ = model.generate(**inputs, max_new_tokens=32)

torch.cuda.synchronize()

# 측정
max_new_tokens = 256
start = time.time()
with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
    )

torch.cuda.synchronize()
elapsed = time.time() - start

generated = out.shape[-1] - inputs["input_ids"].shape[-1]
print("generated_tokens:", generated)
print("sec:", elapsed)
print("tokens/sec:", generated / elapsed)

포인트

  • attn_implementationflash_attention_2를 지정
  • torch.inference_mode()로 autograd를 끄고 오버헤드를 줄임
  • torch.cuda.synchronize()로 GPU 비동기 실행에 의한 측정 오차를 줄임

예제 2: pipeline 사용 시

pipeline은 편하지만 세밀한 제어가 어려울 수 있습니다. 그래도 모델 로딩에 동일하게 옵션을 줄 수 있습니다.

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # 예시

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)

pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
)

res = pipe("한 문장으로 FlashAttention2를 설명해줘.", max_new_tokens=64)
print(res[0]["generated_text"])

진짜로 2배 빨라졌는지: 벤치마크 방법

속도 비교는 “한 번 돌려보고 빠른 것 같다”로 끝내면 재현이 어렵습니다. 다음 기준을 맞추면 비교가 깔끔해집니다.

1) 동일 조건 고정

  • 동일 모델/동일 프롬프트/동일 max_new_tokens
  • 동일 dtype(float16 또는 bfloat16)
  • 동일 use_cache 설정
  • 동일 temperature 및 샘플링 설정

2) 워밍업 필수

첫 실행은 CUDA 커널 로딩, 메모리 할당, 캐시 준비 등으로 느릴 수 있습니다.

  • 워밍업 1~3회
  • 측정은 5회 정도 반복 후 평균/중앙값

3) 지표는 tokens/sec로

요청당 시간도 의미 있지만, 생성 토큰 수가 달라지면 왜곡됩니다. tokens/sec를 기본 지표로 삼는 게 좋습니다.

자주 겪는 문제와 해결

FlashAttention2 적용에서 흔히 만나는 이슈를 정리합니다.

1) flash_attn import 에러

증상: ModuleNotFoundError 또는 로딩 시점에 flash_attn 관련 에러

해결:

  • pip show flash-attn로 설치 여부 확인
  • PyTorch CUDA 버전과 호환되는지 재확인
  • 가상환경을 새로 만들고, PyTorch를 먼저 설치한 뒤 flash-attn을 설치

2) dtype 불일치 또는 정밀도 문제

증상: 특정 dtype에서만 느리거나 에러

해결:

  • 추론은 float16 또는 bfloat16을 권장
  • 모델 로딩 시 torch_dtype를 명시
  • GPU가 bfloat16에 강한지(아키텍처) 확인

3) 기대만큼 빨라지지 않는 경우

원인 후보:

  • 시퀀스 길이가 짧아 attention 최적화 이득이 작음
  • 배치가 1이고, 병목이 attention이 아니라 다른 부분(디코딩/샘플링/CPU 토크나이즈)일 수 있음
  • use_cache=False로 설정되어 매 토큰마다 전체 컨텍스트를 다시 계산

개선 팁:

  • use_cache=True 유지
  • 토크나이저 병목이 의심되면 use_fast=True 사용
  • 가능한 경우 프롬프트를 배치로 묶어 처리

4) OOM이 줄었는데도 여전히 터지는 경우

FlashAttention2는 메모리를 줄여주기도 하지만, 모델 크기 자체가 크면 여전히 OOM이 납니다.

  • max_new_tokens를 줄이거나
  • 더 작은 모델/더 낮은 dtype/양자화로 전환하거나
  • KV 캐시가 커지는 상황(긴 컨텍스트, 큰 배치)을 줄여야 합니다.

GPU 메모리 관점에서 문제를 체계적으로 접근하려면 OOM 원인 분해가 중요합니다. 쿠버네티스 환경에서의 메모리 튜닝 관점은 Kubernetes OOMKilled 메모리 튜닝 실전 가이드도 함께 참고하면 도움이 됩니다.

운영 관점 팁: “로컬”이라도 결국은 시스템 문제

로컬 LLM 추론은 모델/커널만 빠르다고 끝나지 않습니다. 실제 서비스나 내부 도구로 붙이면 다음이 성능을 흔히 갉아먹습니다.

  • 동시성 증가로 인한 GPU 큐잉
  • CPU 토크나이즈/후처리 병목
  • 네트워크/프록시 레이어에서의 지연
  • 컨테이너 리소스 제한으로 인한 스로틀링

특히 쿠버네티스에서 mTLS나 프록시를 얹으면 “모델은 빠른데 체감이 느린” 상황이 생길 수 있습니다. 네트워크 계층 이슈가 의심되면 Kubernetes 서비스메시 mTLS 실패 원인 7가지처럼 원인군을 나눠 점검하는 접근이 유효합니다.

결론: FlashAttention2는 가장 가성비 좋은 추론 최적화 중 하나

Transformers 기반 로컬 LLM에서 FlashAttention2는 적용 난이도 대비 효과가 큰 편입니다. 정리하면 다음 순서로 접근하는 것이 안전합니다.

  1. PyTorch/CUDA/드라이버 조합 확인
  2. flash-attn 설치
  3. 모델 로딩 시 attn_implementationflash_attention_2로 지정
  4. 워밍업 후 tokens/sec로 벤치마크
  5. 기대 이득이 작다면 시퀀스 길이, use_cache, 토크나이저/CPU 병목을 재점검

다음 단계로 더 공격적인 최적화를 원한다면, 양자화(예: 4bit), speculative decoding, continuous batching 같은 전략과 함께 조합하는 것이 보통의 “2배”를 “그 이상”으로 끌어올리는 현실적인 루트입니다.