Published on

PyTorch 8bit 양자화 에러와 속도 최적화

Authors

서빙/배치 추론 비용을 줄이기 위해 8bit 양자화는 거의 필수 옵션이 됐습니다. 특히 LLM 계열에서는 bitsandbytes 기반 8bit 로딩이 널리 쓰이지만, 막상 적용하면 dtype 불일치, 디바이스 매핑 문제, 커널 미지원, 속도 기대치 미달 같은 이슈가 연쇄적으로 터지기 쉽습니다.

이 글은 PyTorch 모델 8bit 양자화에서 자주 발생하는 에러를 빠르게 진단하고, 진짜로 속도가 빨라지도록 설정을 정리한 실전 가이드입니다.

또한 GPU 메모리(OOM)와 성능 트레이드오프는 결국 “병목을 어디서 없애느냐” 문제이므로, VRAM 최적화 관점은 ComfyUI VRAM 폭발? 타일·VAE로 해결하기도 같이 참고하면 도움이 됩니다.


8bit 양자화: 무엇이 빨라지고 무엇이 안 빨라지나

8bit 양자화는 크게 두 가지 효과를 기대합니다.

  1. 메모리 절감: 가중치가 fp16/bf16 대비 대략 절반 수준으로 줄어듭니다.
  2. 대역폭 병목 완화: 큰 모델은 종종 연산보다 메모리 로드가 병목이라, 가중치가 작아지면 유리합니다.

하지만 다음은 오해가 많습니다.

  • 무조건 빨라지지 않습니다. 커널/연산 경로가 8bit에 최적화되지 않으면, 오히려 디퀀타이즈 오버헤드로 느려질 수 있습니다.
  • 모든 레이어가 8bit로 동작하지 않습니다. 일부 연산은 여전히 fp16/bf16로 수행되고, 특히 attention/softmax 계열은 별도 최적화가 필요합니다.

정리하면, 8bit는 “VRAM을 줄이는 효과”는 확실하지만, “속도 향상”은 커널 지원 + 병목 구조 + 배치/시퀀스 길이에 따라 달라집니다.


준비물: 버전/환경 체크리스트

8bit 양자화 문제의 상당수는 환경 불일치에서 시작합니다.

  • NVIDIA GPU + CUDA 환경(대부분의 bitsandbytes 기능은 CUDA 의존)
  • torch, transformers, accelerate, bitsandbytes 버전 호환
  • 드라이버/런타임 불일치(컨테이너에서 특히 흔함)

아래는 최소 점검 코드입니다.

import torch

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("cuda:", torch.version.cuda)
print("gpu:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)

bitsandbytes는 설치는 됐는데 런타임에 CUDA 로딩 실패로 터지는 경우가 많습니다. 에러 메시지에 CUDA Setup failed 류가 보이면, 우선 드라이버와 CUDA 런타임(컨테이너면 베이스 이미지)을 맞추는 게 1순위입니다.


8bit 로딩 기본 패턴(Transformers + bitsandbytes)

가장 흔한 패턴은 transformers에서 8bit 로딩 옵션을 켜는 것입니다.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "gpt2"  # 예시

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

model.eval()

여기서 자주 생기는 함정은 다음입니다.

  • device_map="auto"가 레이어를 여러 디바이스로 쪼개며, 일부 레이어가 CPU로 떨어지면 속도가 급락합니다.
  • torch_dtype는 “모델의 기본 dtype 힌트”일 뿐, 실제로 8bit로 들어가는 가중치와 혼재됩니다.

자주 터지는 에러 1: dtype 불일치(특히 LayerNorm, Embedding)

증상

  • RuntimeError: expected scalar type Half but found Float
  • 혹은 특정 레이어에서만 fp32 텐서가 섞여 연산이 깨짐

원인

8bit 로딩을 하면 일부 모듈은 내부적으로 fp16/bf16 경로를 타고, 또 일부는 안전을 위해 fp32를 유지합니다. 이때 입력 텐서 dtype, 캐시 dtype, 레이어 dtype이 어긋나면 런타임 에러가 납니다.

해결

  1. 입력을 모델이 기대하는 dtype로 맞추고
  2. autocast를 올바르게 사용하며
  3. 필요 시 특정 모듈을 fp16/bf16로 강제합니다.
import torch

@torch.inference_mode()
def generate(model, tokenizer, prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 8bit 혼합 환경에서는 autocast가 안전장치 역할을 하는 경우가 많습니다.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

추가로, 모델이 bf16을 더 잘 받는 환경(Ampere 이상)이라면 dtype=torch.bfloat16로 바꾸는 것이 dtype 충돌을 줄이는 경우도 있습니다.


자주 터지는 에러 2: device_map 때문에 CPU 오프로딩 발생

증상

  • 에러는 없는데 갑자기 추론이 매우 느려짐
  • 로그에 offload 또는 CPU 이동이 암시됨

원인

VRAM이 부족하면 accelerate가 자동으로 일부 레이어를 CPU로 배치합니다. 8bit로 VRAM을 줄였는데도 여전히 부족하면, CPU 오프로딩이 걸리면서 PCIe 전송이 병목이 됩니다.

해결

  • 가능한 한 단일 GPU에 전부 올리기
  • 시퀀스 길이/배치 줄이기
  • 필요 시 더 공격적인 양자화(4bit) 또는 체크포인팅/플래시 어텐션 병행

디바이스 배치를 확인하는 간단한 방법은 파라미터가 어디에 있는지 보는 것입니다.

from collections import Counter

def count_param_devices(model):
    c = Counter()
    for _, p in model.named_parameters():
        c[str(p.device)] += p.numel()
    return c

print(count_param_devices(model))

cpu 비중이 조금이라도 있으면, 실제 체감 속도는 크게 떨어질 수 있습니다.


자주 터지는 에러 3: bitsandbytes 커널/아키텍처 미지원

증상

  • 특정 GPU에서만 실패
  • matmul 혹은 Linear8bitLt 관련 에러

원인

GPU 아키텍처나 CUDA 조합에 따라 bitsandbytes의 사전 빌드 휠이 완전 호환이 아닐 수 있습니다. 특히 오래된 GPU 또는 특이한 CUDA 조합에서 문제가 납니다.

해결

  • bitsandbytes 최신 버전으로 업데이트
  • CUDA/드라이버 조합을 안정적인 조합으로 맞춤
  • 불가피하면 8bit 대신 fp16/bf16 또는 다른 양자화 백엔드로 전환

실무 팁으로는 “서빙 환경”을 로컬과 다르게 가져가면 재현이 어려워지므로, 컨테이너 기반으로 고정하는 편이 낫습니다. 비용 관점에서 로그가 과도하게 쌓이면 디버깅이 더 어려워지니, 필요하면 CloudWatch Logs 비용 폭증 원인과 절감 10가지처럼 관측 비용도 함께 관리하는 것이 좋습니다.


속도 최적화 1: torch.compile은 “되는 경우에만” 강력

PyTorch 2.x의 torch.compile은 특정 그래프에서 큰 이득을 주지만, 8bit 양자화 모듈은 컴파일 호환성이 떨어질 수 있습니다.

  • 잘 되면 레이턴시가 줄어듭니다.
  • 안 되면 컴파일 시간이 길고, fallback이 발생하거나 에러가 납니다.

테스트는 이렇게 하되, 실패 시 즉시 끄는 전략이 안전합니다.

import torch

try:
    model_compiled = torch.compile(model)
except Exception:
    model_compiled = model

실제 서빙에서는 “컴파일 성공 여부”와 “첫 요청 워밍업 지연”까지 포함해 판단해야 합니다.


속도 최적화 2: FlashAttention/SDPA로 attention 병목 줄이기

LLM 추론은 attention이 병목인 경우가 많습니다. 8bit로 가중치만 줄여도 attention 자체는 크게 안 빨라질 수 있습니다.

가능하면 PyTorch의 SDPA(Scaled Dot Product Attention) 경로를 타도록 설정합니다.

import torch

# PyTorch 2.x에서 SDPA 커널 선택 힌트
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

모델/환경에 따라 Flash 커널이 비활성화될 수 있으니, 프로파일링으로 실제 커널이 무엇을 타는지 확인해야 합니다.


속도 최적화 3: KV cache와 시퀀스 전략(배치보다 중요한 경우도)

생성 모델에서 속도를 좌우하는 핵심은 종종 KV cache입니다.

  • use_cache=True가 꺼져 있으면 토큰 생성이 기하급수적으로 느려집니다.
  • 긴 컨텍스트에서 프리필(prefill) 단계가 병목이면, 배치/시퀀스 길이 정책을 바꿔야 합니다.
gen_kwargs = {
    "max_new_tokens": 128,
    "do_sample": False,
    "use_cache": True,
}

서빙에서는 “긴 요청 1개”가 전체 워커를 잡아먹는 현상이 생기므로, 큐를 분리하거나 컨텍스트 길이 제한을 두는 것이 체감 성능을 크게 개선합니다.


속도 최적화 4: 측정이 먼저다(프로파일링 최소 세트)

8bit 적용 후 느려졌다면, 원인은 보통 아래 중 하나입니다.

  • CPU 오프로딩
  • 디퀀타이즈 오버헤드
  • attention 커널 비최적
  • 토크나이저/전처리 병목

최소한의 타이밍 측정 코드를 넣어 “어디가 느린지”부터 확정하세요.

import time
import torch

@torch.inference_mode()
def bench(model, tokenizer, prompt: str, n: int = 20):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 워밍업
    for _ in range(3):
        _ = model.generate(**inputs, max_new_tokens=32, use_cache=True)

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        _ = model.generate(**inputs, max_new_tokens=32, use_cache=True)
    torch.cuda.synchronize()
    t1 = time.perf_counter()

    return (t1 - t0) / n

print("avg sec:", bench(model, tokenizer, "hello"))

여기에 PyTorch profiler까지 붙이면 커널 단위로 확인할 수 있습니다.


8bit 품질/안정성 튜닝: llm_int8_threshold의 의미

llm_int8_threshold는 “이상치(outlier) 채널을 fp16로 처리”하는 기준값으로 이해하면 됩니다.

  • 값을 낮추면 fp16 처리 비중이 늘어 품질/안정성이 좋아질 수 있지만, 메모리/속도 이점이 줄어듭니다.
  • 값을 높이면 더 공격적으로 8bit를 쓰지만, 특정 모델에서는 품질 저하나 불안정이 생길 수 있습니다.

실무에서는 다음 순서로 접근하는 편이 안전합니다.

  1. 기본값으로 로딩해 에러 없이 동작하는지 확인
  2. 품질 문제가 의심되면 threshold를 조정
  3. 속도가 안 나오면 attention/디바이스 매핑/배치 정책을 먼저 손봄

체크리스트: “에러 없이 빠르게” 만들기 위한 우선순위

  1. CPU 오프로딩이 없는지 확인 (device_map 결과 점검)
  2. autocast + 입력 dtype 정리로 dtype 에러 제거
  3. attention 커널이 최적 경로(SDPA/Flash)를 타는지 확인
  4. KV cache가 켜져 있는지 확인
  5. torch.compile은 실측 기반으로 선택
  6. 그래도 느리면 8bit가 답이 아닐 수 있으니 4bit 또는 fp16 최적화로 전환 검토

마무리

PyTorch 8bit 양자화는 “VRAM 절감”에는 매우 효과적이지만, “속도 향상”은 자동으로 따라오지 않습니다. 특히 device_map으로 CPU 오프로딩이 발생하거나 attention 커널이 비최적 경로를 타면, 8bit 적용 후에도 느려지거나 오히려 역효과가 날 수 있습니다.

가장 좋은 전략은 (1) 디바이스 배치 확정 → (2) dtype 안정화 → (3) attention/KV cache 최적화 → (4) 실측 기반 튜닝 순서로 접근하는 것입니다. 이 순서대로만 점검해도, 8bit 적용 시 흔한 에러의 대부분을 빠르게 제거하고, 기대했던 비용/성능 효과에 더 가깝게 도달할 수 있습니다.