- Published on
Transformers 로컬 LLM 느림? FlashAttention2 적용
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
로컬 GPU에서 transformers로 LLM을 돌리다 보면, 같은 모델·같은 프롬프트인데도 토큰 생성 속도(tokens/s)가 기대보다 낮게 나오는 경우가 많습니다. 원인은 다양하지만, 가장 흔한 병목은 어텐션(attention) 계산이 메모리 대역폭에 묶이는 것입니다.
이 글에서는 FlashAttention2(이하 FA2) 를 transformers에 적용해 로컬 LLM 추론을 가속하는 방법을 정리합니다. 단순히 설치만 하고 끝나는 게 아니라, 실제로 속도가 안 오르는 케이스(버전 불일치, dtype, 디바이스 매핑, KV 캐시, VRAM OOM 등)까지 함께 다룹니다.
또한 VRAM이 빡빡해지며 OOM이 나는 상황은 LLM/이미지 모두 자주 겪는 문제라서, 메모리 관점의 접근은 Stable Diffusion 4K 업스케일, VRAM OOM 피하는 법도 같이 참고하면 도움이 됩니다.
왜 로컬 LLM이 느린가: 어텐션이 대부분을 먹는다
LLM 추론의 비용은 크게 2가지로 나뉩니다.
- 프리필(prefill): 입력 시퀀스 전체에 대해 한 번에 계산(처음 한 번 느림)
- 디코딩(decode): 토큰을 하나씩 생성하며 반복(토큰마다 반복)
특히 긴 컨텍스트에서 프리필은 어텐션 계산이 매우 무거워지고, 디코딩은 KV 캐시를 쓰더라도 매 토큰마다 어텐션 연산과 메모리 접근이 반복됩니다.
기본 PyTorch 어텐션은 기능적으로는 좋지만, GPU에서 메모리 이동이 많고 커널 호출이 분절되는 경우가 있어 대역폭 병목이 쉽게 발생합니다. FA2는 이 부분을 커널 퓨전과 메모리 최적화로 개선해 속도를 올립니다.
FlashAttention2가 해결하는 것과 한계
FA2가 잘 먹히는 경우
- NVIDIA GPU에서 FP16/BF16 기반 추론
- 긴 컨텍스트(프리필 가속 체감 큼)
- 디코딩에서도 일정 수준의 토큰/s 개선
FA2만으로 해결 안 되는 경우
- CPU 추론
- 이미
torch의scaled_dot_product_attention가 최적 경로로 잡히는 조합 - 모델이 FA2를 지원하지 않거나(구조/버전), 로딩 옵션이 잘못된 경우
- 병목이 어텐션이 아니라 샘플링/로짓 후처리, 토크나이저, I/O 인 경우
적용 전 체크리스트(가장 중요)
FA2는 “설치만 하면 자동으로 빨라지는” 경우도 있지만, 로컬 환경에서는 아래 조합이 맞아야 합니다.
- GPU: NVIDIA 권장(대부분 CUDA 경로)
- CUDA: 환경에 맞는 버전
- PyTorch: CUDA 빌드 버전 일치
transformers:attn_implementation옵션 지원 버전- dtype: 보통
torch.float16또는torch.bfloat16
속도 튜닝하다가 VRAM이 터지면 원인 파악이 어려워지는데, 이때는 쿠버네티스 환경이라면 OOMKilled 진단 방식이 그대로 도움이 됩니다. 메모리 압박 징후를 보는 관점은 EKS Pod CrashLoopBackOff? OOMKilled 진단법도 참고할 만합니다.
설치: flash-attn(FlashAttention2) 제대로 깔기
대부분의 삽질은 설치에서 시작합니다. 핵심은 PyTorch CUDA 버전과 flash-attn 빌드가 맞아야 한다는 점입니다.
1) 가상환경 준비
python -m venv .venv
source .venv/bin/activate
pip install -U pip
2) PyTorch 설치(예시)
아래는 예시이며, 본인 CUDA 환경에 맞춰 설치해야 합니다.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
설치 후 확인:
python -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.cuda.is_available())"
3) transformers 설치
pip install -U transformers accelerate
4) flash-attn 설치
환경에 따라 가장 안전한 방법은 공식 가이드를 따르는 것이지만, 보통은 아래처럼 진행합니다.
pip install -U flash-attn --no-build-isolation
설치 확인:
python -c "import flash_attn; print('flash_attn ok')"
설치가 자주 실패하는 케이스
- 컴파일 툴체인 누락(
gcc,nvcc) - CUDA toolkit 미설치 또는 버전 불일치
- PyTorch가 CPU 빌드인데 CUDA 확장을 설치하려는 경우
이 경우는 에러 로그를 보고, torch.version.cuda와 로컬 CUDA toolkit을 맞추는 것이 우선입니다.
Transformers에서 FA2 활성화: attn_implementation
transformers는 모델 로딩 시 어텐션 구현체를 선택할 수 있습니다. 핵심 옵션은 attn_implementation 입니다.
아래 예시는 AutoModelForCausalLM 기준입니다.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-2-7b-hf" # 예시
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation="flash_attention_2",
)
prompt = "Explain FlashAttention2 in simple terms."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
use_cache=True,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
옵션 이름/지원 여부 확인
모델/버전에 따라 flash_attention_2를 지원하지 않거나, 내부적으로 다른 경로로 fallback될 수 있습니다. 속도가 그대로라면 “적용이 안 된 것”일 수 있으니, 아래를 꼭 확인하세요.
- 로딩 시 경고/로그
transformers버전 업- 해당 아키텍처가 FA2를 지원하는지
속도 측정: 토큰/s로 전후 비교하기
체감만으로는 튜닝이 어렵습니다. 최소한 토큰/s를 재야 합니다.
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-2-7b-hf" # 예시
def bench(attn_impl):
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation=attn_impl,
)
prompt = "Write a short technical note about GPU attention kernels."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# 워밍업
with torch.inference_mode():
_ = model.generate(**inputs, max_new_tokens=32, do_sample=False)
torch.cuda.synchronize()
start = time.time()
new_tokens = 256
with torch.inference_mode():
out = model.generate(**inputs, max_new_tokens=new_tokens, do_sample=False)
torch.cuda.synchronize()
elapsed = time.time() - start
total_len = out.shape[-1]
prompt_len = inputs["input_ids"].shape[-1]
gen_len = total_len - prompt_len
tps = gen_len / elapsed
return tps
for impl in ["eager", "flash_attention_2"]:
try:
tps = bench(impl)
print(impl, "tokens/s:", round(tps, 2))
except Exception as e:
print(impl, "failed:", e)
eager: 기본(비최적) 경로 비교용flash_attention_2: FA2
여기서 차이가 없다면, FA2가 적용되지 않았거나 병목이 다른 곳에 있을 확률이 큽니다.
자주 겪는 문제와 해결
1) attn_implementation을 줬는데도 느리다
가능성이 큰 순서대로:
transformers가 오래된 버전이라 옵션이 무시됨- 모델이 해당 어텐션 구현을 지원하지 않아 fallback
- dtype이 FP32로 올라가서 커널 최적화가 깨짐
device_map이 잘못되어 일부 레이어가 CPU에 올라감
체크 코드:
import torch
print("dtype:", next(model.parameters()).dtype)
print("device:", next(model.parameters()).device)
print("cuda:", torch.cuda.get_device_name(0))
2) VRAM OOM이 난다(FA2 적용 후 더 자주)
FA2 자체가 항상 메모리를 더 쓰는 건 아니지만, 다음 조합에서 OOM이 쉽게 납니다.
- 컨텍스트 길이를 크게 올림
- 배치 크기 증가
max_new_tokens증가- KV 캐시 사용(
use_cache=True)으로 디코딩 메모리 증가
대응책:
max_new_tokens/입력 길이 제한- 배치 줄이기
torch_dtype=torch.float16또는 BF16로 고정- 필요하면 4bit/8bit 양자화(
bitsandbytes) 고려
VRAM 관리 원칙은 이미지 생성에서도 동일합니다. “왜 OOM이 나는지”를 구조적으로 보는 방법은 Stable Diffusion 4K 업스케일, VRAM OOM 피하는 법에서 소개한 접근과 거의 같습니다.
3) Windows에서 설치가 까다롭다
flash-attn은 플랫폼/컴파일 환경 영향을 크게 받습니다. Windows에서 소스 빌드가 막히면 다음 우회가 실무적으로 흔합니다.
- WSL2에서 CUDA 사용
- Linux 머신/서버에서 개발
- 컨테이너 환경으로 고정
4) 프리필만 빨라지고 디코딩은 별 차이 없다
정상일 수 있습니다.
- 긴 입력(프리필)에서는 이득이 크게 나고
- 짧은 입력 + 짧은 디코딩에서는 샘플링/후처리 비중이 커져 이득이 줄어듭니다.
이때는 다음도 같이 보세요.
do_sample=False로 비교(샘플링 비용 제거)use_fast=True토크나이저- 출력 스트리밍/로깅이 병목인지
FA2 적용 외 추가로 같이 하면 좋은 튜닝
1) torch.compile은 신중히
PyTorch 2.x의 torch.compile이 일부 워크로드에서 성능 향상을 주기도 하지만, LLM generate 루프에서는 오히려 초기 컴파일 오버헤드나 그래프 브레이크로 기대만큼 이득이 없을 수 있습니다. FA2는 상대적으로 “바로 체감”되는 경우가 많아 우선순위가 높습니다.
2) KV 캐시와 컨텍스트 전략
- 대화형 챗봇: KV 캐시 유지가 필수
- 배치 추론: 입력 길이·배치·출력 길이의 곱이 VRAM을 지배
운영 환경에서 메모리 압박이 누적되면 결국 프로세스가 죽거나 재시작 루프를 타는데, 이런 “터지기 직전 징후”를 잡는 관점은 인프라에서도 동일합니다. 쿠버네티스라면 EKS Pod CrashLoopBackOff? OOMKilled 5분 진단 같은 체크리스트를 응용해 로컬에서도 메모리/로그를 체계적으로 보세요.
3) 양자화와의 조합
4bit/8bit 양자화는 VRAM을 크게 줄여주지만, 커널 조합에 따라 속도가 오히려 떨어지거나(특히 작은 배치), 특정 GPU에서 효율이 달라질 수 있습니다.
- 목표가 “단일 스트림 최대 토큰/s”인지
- 목표가 “같은 VRAM에서 더 큰 모델”인지
를 먼저 정하고 선택하는 것이 좋습니다.
결론: 로컬 LLM 속도 문제의 1순위 처방
로컬에서 transformers 기반 LLM이 느릴 때, 가장 먼저 확인할 최적화는 다음 3가지입니다.
- dtype을 FP16/BF16로 고정했는가
- 모델이 GPU에 온전히 올라갔는가(
device_map확인) attn_implementation으로 FA2가 실제 적용됐는가
FA2는 특히 긴 컨텍스트에서 프리필을 확실히 끌어올려 “로컬 LLM이 답답한 느낌”을 줄여줍니다. 다만 설치/버전 호환이 까다로운 편이므로, 위의 체크리스트와 벤치 코드로 적용 여부를 숫자로 검증하면서 진행하는 것을 권장합니다.