- Published on
Transformers 로컬 LLM 느림? FlashAttention2 적용법
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU에서 transformers로 LLM을 돌리는데 토큰 생성 속도가 기대보다 느리면, 가장 먼저 의심할 만한 지점이 어텐션 구현입니다. 기본 scaled dot-product attention은 메모리 트래픽이 크고, 특히 긴 컨텍스트에서 병목이 심해지기 쉽습니다.
FlashAttention2는 어텐션을 타일링하고 GPU SRAM을 적극 활용해 메모리 접근을 줄이는 방식으로, 같은 모델이라도 토큰 생성 속도와 VRAM 사용량을 눈에 띄게 개선하는 경우가 많습니다. 이 글에서는 transformers에서 FlashAttention2를 적용하는 실전 절차와, 적용이 안 되거나 오히려 느려질 때 체크할 포인트를 정리합니다.
참고로, 로컬 추론 최적화는 FlashAttention 외에도 TensorRT, 양자화, 컴파일 등 선택지가 많습니다. ONNX 및 TensorRT로 넘어가다 생기는 삽질 포인트는 별도 글인 PyTorch→ONNX→TensorRT INT8 양자화 오류 해결도 함께 참고하면 좋습니다.
왜 로컬 LLM이 느릴까: 병목을 먼저 분리하기
체감상 “느리다”는 현상은 크게 두 가지로 나뉩니다.
- 프리필(prefill) 구간이 느림: 입력 프롬프트를 한 번에 인코딩하는 단계가 오래 걸립니다. 긴 컨텍스트일수록 두드러집니다.
- 디코딩(decode) 구간이 느림: 토큰을 한 개씩 생성하는 단계가 느립니다. KV 캐시, 배치 크기, 샘플링 설정, GPU utilization에 영향을 받습니다.
FlashAttention2는 주로 어텐션 연산 자체를 최적화하므로, 특히 긴 컨텍스트의 프리필에서 이득이 크고, 디코딩에서도 모델과 환경에 따라 개선이 나올 수 있습니다. 다만 아래 조건이면 효과가 제한되거나 적용 자체가 안 될 수 있습니다.
- GPU가 너무 구형이거나(아키텍처 제약)
- dtype이 맞지 않거나(
fp16또는bf16권장) transformers버전과flash-attn빌드가 안 맞거나- 모델이 FlashAttention2 경로를 지원하지 않거나
적용 전 체크리스트
1) GPU와 CUDA 환경
- NVIDIA GPU 필요(대부분의 경우)
- CUDA Toolkit 및 드라이버 호환
- PyTorch가 해당 CUDA 버전으로 설치되어 있어야 함
아래로 현재 환경을 빠르게 확인합니다.
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))
get_device_capability는 예를 들어 8.0, 8.6, 9.0 같은 형태로 나오며, 일반적으로 Ampere 이상에서 안정적인 성능 이득을 기대하는 편입니다.
2) dtype과 메모리
FlashAttention2는 보통 float16 또는 bfloat16에서 잘 동작합니다. float32로 돌리면 속도도 느리고 VRAM도 크게 먹습니다.
- 소비자 GPU에서
bf16이 애매하면fp16부터 시도 - 서버 GPU(A100, H100 등)면
bf16도 좋은 선택
3) 모델이 지원하는지
최근 transformers는 많은 디코더 계열 모델에서 attn_implementation 옵션으로 FlashAttention2를 선택할 수 있습니다. 다만 모델 아키텍처나 버전에 따라 지원 여부가 다릅니다.
FlashAttention2 설치: 가장 흔한 실패 지점
FlashAttention2는 보통 flash-attn 패키지로 설치합니다. 문제는 이 패키지가 CUDA 확장 빌드를 포함하기 때문에, PyTorch CUDA 버전, 컴파일러, CUDA Toolkit, 그리고 OS 환경이 어긋나면 설치가 실패하거나 런타임에서 터집니다.
권장: 가상환경에서 버전 고정
python -m venv .venv
source .venv/bin/activate
pip install -U pip
PyTorch 설치 확인
이미 설치되어 있다면 넘어가도 되지만, CUDA 포함 빌드인지 확인하세요.
python -c "import torch; print(torch.__version__, torch.version.cuda)"
flash-attn 설치
환경에 따라 설치 커맨드는 달라질 수 있지만, 기본적으로는 아래를 시도합니다.
pip install -U flash-attn --no-build-isolation
--no-build-isolation은 종종 빌드 의존성 충돌을 줄여줍니다.- 설치가 실패하면 에러 로그에
nvcc,gcc,ninja,CUDA_HOME등이 언급되는지 확인하세요.
설치 확인
python -c "import flash_attn; print('flash_attn ok')"
Transformers에서 FlashAttention2 켜기
핵심은 모델 로딩 시 attn_implementation을 지정하는 것입니다. 또한 dtype을 float16 또는 bfloat16로 맞추고, 가능하면 device_map을 사용해 GPU에 올립니다.
아래 예시는 일반적인 causal LM 추론 코드입니다.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation="flash_attention_2",
)
prompt = "Explain FlashAttention2 in simple terms."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
적용 여부 확인 팁
- 속도가 눈에 띄게 빨라졌는지(특히 긴 프롬프트에서)
- VRAM 사용량이 줄었는지
- 경고 로그에 “flash attention” 관련 메시지가 있는지
또는 모델 설정에 구현이 기록되는 경우도 있습니다.
print(getattr(model.config, "attn_implementation", None))
벤치마크: 토큰/초로 확인하기
체감만으로 판단하면 샘플링 옵션, 온도, 반복 실행 시 캐시 워밍 등으로 착시가 생깁니다. 아래처럼 토큰 생성 속도를 간단히 재보는 편이 좋습니다.
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
def load(attn_impl):
return AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation=attn_impl,
)
def bench(model, prompt, new_tokens=256):
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
torch.cuda.synchronize()
t0 = time.time()
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=new_tokens,
do_sample=False,
)
torch.cuda.synchronize()
dt = time.time() - t0
total_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
return total_tokens / dt
prompt = "Write a detailed technical explanation of KV cache and attention." * 20
model_eager = load("eager")
print("eager tok/s:", bench(model_eager, prompt))
del model_eager
torch.cuda.empty_cache()
model_fa2 = load("flash_attention_2")
print("fa2 tok/s:", bench(model_fa2, prompt))
- 프롬프트를 길게 만들어야 프리필 병목 차이가 잘 드러납니다.
torch.cuda.synchronize()를 넣어야 측정이 정확해집니다.
자주 겪는 문제와 해결 체크포인트
1) 적용했는데도 속도가 그대로인 경우
프롬프트가 너무 짧다
짧은 입력에서는 어텐션 최적화 이득이 작고, 오히려 오버헤드가 더 커 보일 수 있습니다. 긴 컨텍스트에서 다시 측정하세요.
디코딩 병목은 어텐션만의 문제가 아니다
토큰 생성은 다음 요소도 크게 좌우합니다.
- 샘플링 설정(예:
top_p,temperature) - 배치 크기
max_new_tokensuse_cache활성화 여부
generate는 기본적으로 use_cache를 쓰지만, 모델 설정에 따라 달라질 수 있으니 확인합니다.
print("use_cache:", getattr(model.config, "use_cache", None))
2) 설치는 됐는데 런타임 에러가 나는 경우
대표적으로 CUDA 커널 로딩 실패, 심볼 불일치, ABI 문제 등이 있습니다. 이런 경우는 대개 아래 중 하나입니다.
- PyTorch CUDA 버전과
flash-attn빌드 대상 CUDA가 다름 - 드라이버가 너무 낮음
- 여러 CUDA가 섞여
LD_LIBRARY_PATH가 꼬임
해결 전략은 단순합니다.
- PyTorch를 먼저 원하는 CUDA 버전으로 재설치
- 그 다음
flash-attn을 재설치 - 가능하면 깨끗한 가상환경에서 재현
빌드/캐시가 꼬였을 때는 Docker를 쓰는 것도 방법인데, 캐시가 기대대로 동작하지 않으면 Docker BuildKit 캐시가 안 먹는 9가지 원인 같은 체크리스트가 도움이 됩니다.
3) attn_implementation 옵션이 먹지 않는 경우
모델/버전에 따라 옵션명이 다르거나, 해당 아키텍처가 FlashAttention2 경로를 제공하지 않을 수 있습니다.
transformers를 최신으로 업데이트
pip install -U transformers accelerate
- 모델이
flash_attention_2를 지원하는지 릴리즈 노트나 이슈를 확인 - 그래도 안 되면
eager또는sdpa로 바꿔 성능을 비교
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation="sdpa",
)
여기서 sdpa는 PyTorch의 scaled dot-product attention 경로를 의미하며, 환경에 따라 꽤 좋은 성능이 나올 수도 있습니다.
성능을 더 끌어올리는 실전 팁
1) torch.compile은 신중하게
PyTorch 2 계열에서 torch.compile로 이득을 보는 경우도 있지만, 모델/드라이버/커널 조합에 따라 컴파일 시간이 과도하거나 오히려 느려질 수 있습니다. FlashAttention2부터 안정화한 뒤, 별도 실험으로 접근하는 편이 안전합니다.
2) 긴 컨텍스트에서 VRAM이 빡빡하면
FlashAttention2는 메모리 효율이 좋아 긴 입력에서 특히 유리합니다. 그래도 VRAM이 부족하면 아래를 함께 고려합니다.
max_new_tokens를 줄여 디코딩 길이 제한- 더 작은 모델
- 4bit 또는 8bit 양자화(예: bitsandbytes)
3) 로컬 서비스로 올릴 때는 병목이 다른 곳에 생긴다
로컬에서 단일 프로세스로 돌릴 때는 GPU가 병목이지만, API 서버로 감싸면 다음이 병목이 되기도 합니다.
- 토크나이저 CPU 사용률
- 동시 요청 처리 방식
- 스트리밍 응답 구현
웹 프론트가 느려지는 병목을 진단하는 방식은 LLM과는 다르지만, “병목을 측정으로 쪼개서 해결한다”는 접근은 같습니다. 프론트 렌더 병목을 체계적으로 보는 글로는 Next.js INP 폭증? React 렌더 병목 7단계 진단이 참고가 됩니다.
결론: FlashAttention2는 가장 먼저 시도할 최적화
로컬에서 transformers로 LLM을 돌릴 때 속도가 안 나오면, FlashAttention2는 “설치와 적용만 성공하면 체감 이득이 큰” 최적화 중 하나입니다.
정리하면 순서는 다음이 가장 안전합니다.
- PyTorch CUDA 버전과 드라이버 정합성 확인
flash-attn설치 및 import 확인- 모델 로딩에
attn_implementation을flash_attention_2로 지정 - 긴 프롬프트로 토큰/초 벤치마크
- 여전히 느리면
sdpa비교, dtype, KV cache, 샘플링 설정 점검
이 과정을 거치면 “그냥 느리다”에서 “어느 구간이 느린지, 어떤 옵션이 효과가 있는지”로 문제를 구체화할 수 있고, 이후 양자화나 TensorRT 같은 더 큰 최적화로 확장하기도 쉬워집니다.