Published on

Stable Diffusion VRAM OOM 12가지 최적화

Authors

Stable Diffusion을 돌리다 보면 가장 흔한 장애가 CUDA out of memory(VRAM OOM)입니다. 특히 1024x1024, 배치 증가, ControlNet·LoRA 다중 적용, 고해상도 업스케일을 섞는 순간 VRAM이 급격히 치솟습니다.

이 글은 “VRAM이 왜 터지는지”를 구성요소별로 나눠 보고, 재현 가능한 형태로 12가지 최적화를 제시합니다. 목표는 단순히 해상도를 낮추는 꼼수가 아니라, 같은 품질을 유지하면서 피크 VRAM을 낮추는 것입니다.

참고로 시스템 RAM이 부족해 스왑이 과도하거나 프로세스가 강제 종료되는 케이스는 VRAM OOM과 증상이 비슷합니다. 리눅스 환경에서 프로세스가 갑자기 죽는다면 리눅스 OOM Killer로 프로세스 죽을 때 원인 추적도 함께 확인하세요.

VRAM OOM이 생기는 구조: 어디서 메모리가 터지나

Stable Diffusion의 메모리 사용은 크게 다음으로 나뉩니다.

  • UNet 활성화(activations): 샘플링 단계에서 가장 큰 비중. 해상도, 배치, 스텝, attention 구현에 따라 피크가 변합니다.
  • Attention 키·쿼리·밸류 텐서: 해상도 증가에 매우 민감. 특히 1024 이상에서 급격히 증가.
  • VAE 디코드/인코드: latent를 이미지로 변환할 때 피크가 튈 수 있음.
  • 모델 가중치(UNet, Text Encoder, VAE): fp16, bf16, fp32 여부와 로딩 방식에 따라 상주 VRAM이 변함.
  • 추가 모듈: ControlNet, IP-Adapter, 여러 LoRA, T2I-Adapter 등.

따라서 최적화도 “가중치 상주량 줄이기”, “attention 피크 줄이기”, “활성화 저장 줄이기”, “파이프라인을 분리해 피크 분산하기”로 나누는 게 효과적입니다.

최적화 1) 해상도·배치·스텝의 VRAM 영향부터 정량화

가장 먼저 해야 할 일은 감으로 튜닝하지 않고, 피크 VRAM을 수치로 확인하는 것입니다.

  • VRAM은 대체로 width * height에 비례하고, attention은 더 가파르게 증가합니다.
  • batch_size는 거의 선형으로 증가합니다.
  • steps는 “누적”이 아니라 “반복”이므로 보통 피크에는 큰 영향이 없지만, 일부 구현(특히 특정 최적화 비활성)에서는 캐시/버퍼로 피크가 달라질 수 있습니다.

PyTorch에서 피크 VRAM 측정 코드

import torch

def vram_peak_mb():
    return torch.cuda.max_memory_allocated() / 1024 / 1024

torch.cuda.reset_peak_memory_stats()
# ... inference 실행 ...
print(f"peak VRAM: {vram_peak_mb():.1f} MB")

이 측정이 있어야 아래 최적화가 “효과가 있었는지” 바로 판단됩니다.

최적화 2) fp16 또는 bf16로 가중치·연산 정밀도 낮추기

fp32로 돌면 대부분의 GPU에서 VRAM이 너무 빡빡합니다. 일반적으로 다음 우선순위를 권장합니다.

  • Ampere 이후(RTX 30, A10, A100)는 fp16이 무난
  • Ada(RTX 40)는 fp16 + 최적화 조합이 좋음
  • 일부 환경에서는 bf16이 안정적(오버플로 이슈 감소)

diffusers 파이프라인 예시

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

주의: fp16은 품질보다 안정성 이슈가 가끔 있습니다. 특정 VAE에서 NaN이 나면 VAE만 fp32로 두는 방식도 고려하세요.

최적화 3) xFormers 또는 SDPA로 attention 메모리 절감

OOM의 핵심은 attention 텐서 폭증입니다. 이를 줄이는 대표 옵션이 xFormers와 PyTorch의 SDPA입니다.

  • xFormers: 메모리 효율적인 attention 구현. 설치가 번거로울 수 있으나 효과가 큼.
  • SDPA: PyTorch 2 계열에서 기본 제공되는 최적화 경로. 환경에 따라 가장 안정적.

diffusers에서 xFormers 활성화

pipe.enable_xformers_memory_efficient_attention()

PyTorch SDPA를 쓰도록 유도(환경에 따라)

import torch
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

이 한 방으로 1024에서 OOM이 나던 환경이 살아나는 경우가 많습니다.

최적화 4) VAE 타일링으로 디코드 피크 줄이기

고해상도에서 마지막 VAE 디코드 단계가 VRAM 피크를 만들기도 합니다. 이때 VAE tiling이 효과적입니다.

pipe.enable_vae_tiling()

타일링은 속도는 다소 느려지지만, VRAM을 크게 절약합니다.

최적화 5) VAE 슬라이싱으로 메모리 분할

VAE를 한 번에 처리하지 않고 슬라이스 단위로 처리합니다.

pipe.enable_vae_slicing()

타일링과 슬라이싱은 함께 쓰기도 하고, 상황에 따라 하나만 써도 됩니다.

최적화 6) CPU 오프로딩으로 상주 VRAM 줄이기

VRAM이 작은 GPU(6GB, 8GB)에서 가장 현실적인 방법 중 하나가 오프로딩입니다.

  • model CPU offload: 사용하지 않는 모듈을 CPU로 내림
  • sequential CPU offload: 더 공격적으로 단계별로 내림(속도 희생)
pipe.enable_model_cpu_offload()
# 또는
pipe.enable_sequential_cpu_offload()

오프로딩은 PCIe 전송 때문에 느려질 수 있지만, “아예 OOM으로 실패”보다 훨씬 낫습니다.

최적화 7) torch.compile은 상황별로: 속도만이 아니라 메모리도 변한다

PyTorch 2의 torch.compile은 속도 개선이 주목적이지만, 그래프 캡처/커널 융합으로 메모리 패턴도 달라집니다.

  • 어떤 모델은 피크가 줄고
  • 어떤 모델은 컴파일 캐시/버퍼로 피크가 늘 수 있습니다
import torch
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")

권장: 컴파일 전후로 반드시 max_memory_allocated를 비교하세요.

최적화 8) 배치 대신 “시드 반복 생성”으로 작업 방식 바꾸기

batch_size=2로 한 번에 2장을 뽑는 것과, batch_size=1로 2번 돌리는 것은 총 시간은 다를 수 있어도 피크 VRAM은 전자가 훨씬 불리합니다.

  • OOM이 자주 난다면 batch_size=1 고정
  • 여러 장이 필요하면 루프로 반복 생성
images = []
for i in range(4):
    out = pipe(prompt, num_inference_steps=30, generator=torch.Generator("cuda").manual_seed(100+i))
    images.append(out.images[0])

최적화 9) Hires fix(2-pass)에서 업스케일 전략 조정

고해상도는 한 번에 가면 OOM이 납니다. 그래서 흔히 2-pass로 갑니다.

  • 1차: 낮은 해상도에서 구도 생성
  • 2차: latent 업스케일 후 디테일 강화

하지만 2차에서 denoise를 과하게 올리거나, 업스케일 배율을 크게 잡으면 다시 OOM이 납니다.

실전 팁:

  • 1차를 512 또는 640 근처로 고정
  • 2차 업스케일은 1.5x 또는 2x부터
  • 2차 denoise는 보통 0.3 전후부터 시작

즉 “한 번에 1024” 대신 “640 생성 후 960 업스케일” 같은 타협이 VRAM 대비 품질이 좋습니다.

최적화 10) ControlNet·IP-Adapter·LoRA는 ‘개수’가 곧 VRAM이다

추가 모듈은 VRAM 상주량을 늘리고, UNet 연산량도 늘립니다.

  • ControlNet 2개 이상: VRAM 급증
  • IP-Adapter + ControlNet 동시: 피크가 크게 튐
  • LoRA 여러 개: 가중치 자체는 작아도 적용 경로에 따라 메모리/속도 영향

권장 순서:

  1. 기능을 하나씩 켜며 피크 VRAM 측정
  2. 필요한 것만 남기기
  3. 다중 ControlNet이면 해상도나 batch_size를 반드시 낮추기

최적화 11) 메모리 파편화 줄이기: allocator 설정과 워밍업

OOM이 “남은 VRAM이 있는데도” 발생하는 경우는 파편화(fragmentation)일 수 있습니다.

  • PyTorch CUDA allocator가 큰 연속 블록을 못 잡으면 OOM이 납니다.
  • 특히 여러 번 생성하면서 해상도/옵션이 바뀌면 파편화가 심해집니다.

환경 변수로 allocator 튜닝

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128

워밍업(같은 해상도/옵션으로 1회 예열)

_ = pipe("warmup", num_inference_steps=5)

워밍업은 커널/버퍼가 초기화되며 이후 피크가 안정되는 경우가 있습니다.

최적화 12) 실패를 줄이는 운영 팁: 캐시 정리와 프로세스 격리

실무에서 OOM은 “한 번”이 아니라 “가끔” 터지는 형태로 옵니다. 이때는 운영 레벨의 습관이 중요합니다.

  • 생성 작업 전후로 캐시 정리
  • 한 프로세스에서 여러 모델을 번갈아 로딩하지 않기
  • 노트북/서버에서 다른 GPU 작업(브라우저, 게임, 다른 학습)이 끼지 않게 격리

캐시 정리 코드(필요할 때만)

import gc
import torch

gc.collect()
torch.cuda.empty_cache()

empty_cache()는 “사용 중인 텐서”를 없애는 게 아니라, PyTorch가 잡고 있던 캐시를 반환하는 용도입니다. 남발하면 오히려 느려질 수 있으니 OOM 직후 복구나 큰 설정 변경 직후에만 쓰는 편이 낫습니다.

OOM 디버깅 체크리스트(빠른 진단)

아래 순서로 보면 원인 추적이 빨라집니다.

  1. batch_size2 이상인가
  2. 해상도가 1024 이상인가
  3. ControlNet·IP-Adapter·다중 LoRA를 동시에 켰는가
  4. attention 최적화(xFormers 또는 SDPA)가 켜져 있는가
  5. VAE tiling·slicing을 켰는가
  6. 오프로딩을 적용했는가
  7. 파편화가 의심되면 PYTORCH_CUDA_ALLOC_CONF 적용했는가

서버 환경(예: 쿠버네티스)에서 노드 디스크나 메모리 압박이 동반되면 증상이 더 복잡해집니다. 컨테이너/노드 자원 이슈가 함께 의심되면 EKS 노드 디스크 100%로 Pod Evicted 해결법도 같이 점검하면 좋습니다.

결론: “품질 유지”를 위한 추천 조합

마지막으로, VRAM이 빡빡한 환경에서 성공률이 높은 조합을 정리합니다.

  • 기본: fp16 + xFormers(또는 SDPA)
  • 고해상도: VAE tiling 또는 slicing
  • 8GB 이하: model_cpu_offload까지 고려
  • 작업 방식: batch_size=1 고정 + 반복 생성
  • 확장 기능: ControlNet/Adapter는 하나씩 켜며 측정

OOM은 “GPU가 약해서”라기보다, 피크가 어디서 생기는지 모르고 옵션을 쌓아 올려서 생기는 경우가 많습니다. 위 12가지를 체크리스트처럼 적용하면, 같은 GPU에서도 실패율을 크게 낮추고 더 높은 해상도/복잡한 파이프라인을 안정적으로 운용할 수 있습니다.