Published on

Stable Diffusion LoRA 학습 OOM 해결 - xFormers·FP16·캐시

Authors

Stable Diffusion LoRA를 돌리다 보면 가장 흔하게 마주치는 에러가 CUDA out of memory 입니다. 배치 크기를 줄이면 해결될 것 같지만, 실제로는 어텐션 구현, 정밀도(FP16·BF16), 캐시 전략(VAE/텍스트 인코더), 옵티마이저 상태 메모리, 체크포인트/그래디언트 저장 방식이 동시에 얽혀 OOM을 만듭니다.

이 글은 LoRA 학습 파이프라인에서 VRAM이 어디서 터지는지 분해하고, xFormers·FP16·캐시 튜닝을 중심으로 “재현 가능하게” 메모리를 줄이는 체크리스트를 제공합니다.

관련해서 메모리 병목을 캐시 관점으로 접근하는 사고법은 로컬 LLM OOM 글에서도 유사합니다. 필요하면 함께 참고하세요: Transformers 로컬 LLM OOM 해결 - KV캐시·PagedAttention


OOM이 나는 지점부터 정확히 잡기

LoRA 학습에서 VRAM을 크게 쓰는 구간은 대략 아래 순서로 나타납니다.

  1. U-Net forward/backward: 어텐션의 Q/K/V 및 중간 activation 저장이 가장 큼
  2. VAE encode: 이미지 pixel 을 latent로 바꾸는 과정(특히 고해상도)
  3. Text encoder forward: 프롬프트 토큰 임베딩 계산(반복되면 낭비)
  4. Optimizer states: Adam 계열은 파라미터 외에 m/v 상태를 들고 있어 메모리 증가
  5. Checkpointing/EMA: 저장/추적 옵션이 VRAM을 잡아먹는 경우

OOM 메시지에 reservedallocated 가 같이 보이면, “실제로 쓴 메모리” 뿐 아니라 PyTorch 캐싱 할당자가 잡고 있는 메모리도 영향을 줍니다. 따라서 단순히 배치를 줄이는 것보다 어텐션 메모리 최적화 + 혼합정밀 + 캐시 재사용을 먼저 적용하는 편이 효율이 좋습니다.


1) xFormers 또는 SDPA로 어텐션 메모리부터 줄이기

왜 어텐션이 VRAM을 터뜨리나

기본 어텐션은 시퀀스 길이가 커질수록(=해상도 증가로 토큰 수 증가) 메모리 사용량이 급증합니다. Stable Diffusion의 U-Net은 여러 해상도 스케일에서 어텐션을 수행하고, 학습 시에는 backward를 위해 activation을 저장하므로 OOM이 쉽게 납니다.

선택지: xFormers vs PyTorch SDPA

  • xFormers memory efficient attention: 오래전부터 SD 생태계에서 많이 씀
  • PyTorch 2.x SDPA: scaled_dot_product_attention 기반, 커널/드라이버 조합에 따라 성능과 안정성이 좋음

실무적으로는 “둘 중 하나만 제대로 켜도” 체감 VRAM이 크게 내려갑니다.

diffusers 기반 학습에서 켜는 법

아래는 diffusers U-Net에 xFormers를 적용하는 전형적인 코드입니다.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# xFormers
pipe.unet.enable_xformers_memory_efficient_attention()

PyTorch SDPA를 강제하려면(환경에 따라 다름) 아래처럼 백엔드 플래그를 조절합니다.

import torch

torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

체크 포인트

  • xFormers는 CUDA 버전, PyTorch 버전, GPU 아키텍처에 따라 설치가 까다로울 수 있습니다.
  • SDPA는 PyTorch 2.x에서 기본 경로가 되기도 하지만, 커널 선택이 자동이라 “메모리 효율”이 항상 보장되진 않습니다.

권장 순서는 다음입니다.

  1. xFormers 적용 시도
  2. 실패하거나 불안정하면 PyTorch 2.x SDPA 튜닝
  3. 둘 다 애매하면 해상도/배치/체크포인팅으로 보완

2) FP16·BF16 혼합정밀로 activation과 optimizer 부담 줄이기

LoRA는 “학습 파라미터 수가 적다”는 장점이 있지만, activation 메모리는 여전히 U-Net 크기와 해상도에 의해 결정됩니다. 혼합정밀은 activation을 줄이는 가장 즉각적인 방법입니다.

FP16과 BF16의 선택

  • FP16: VRAM 절감 효과 큼, 다만 오버플로우/언더플로우로 NaN 이 날 수 있음
  • BF16: 범위가 넓어 안정적, Ampere 이상 GPU에서 특히 유리. 다만 환경에 따라 속도/지원이 다름

accelerate에서 mixed precision 설정

accelerate config 로 설정하거나, 코드에서 명시할 수 있습니다.

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # 또는 "bf16"

학습 스크립트가 diffusers 예제 기반이라면 실행 옵션으로도 자주 받습니다.

accelerate launch train_lora.py \
  --mixed_precision fp16 \
  --gradient_accumulation_steps 2

NaN이 뜰 때의 실전 대응

  • bf16 로 전환(가능한 GPU라면 최우선)
  • 학습률을 낮추고 워밍업을 늘림
  • optimizer를 AdamW 기본에서 메모리 절약형으로 바꾸기 전, 먼저 정밀도 안정화부터 확인

3) “캐시”로 중복 계산을 제거하기: 텍스트 인코더·VAE

OOM은 단순히 “한 번에 너무 많이 올린다” 뿐 아니라, 매 step마다 반복되는 계산을 매번 GPU에서 다시 돌리는 구조 때문에 발생하기도 합니다. LoRA 학습은 특히 데이터셋이 크고 step이 많으니, 중복을 줄이는 것이 누적 효율에 크게 기여합니다.

텍스트 인코더 출력 캐시

프롬프트가 고정(또는 제한된 템플릿)인 학습에서는 텍스트 인코더의 hidden state를 캐시할 수 있습니다.

  • 장점: 텍스트 인코더 forward 비용 및 일부 activation 부담 감소
  • 단점: 프롬프트 변형이 많으면 캐시 적중률이 떨어짐

diffusers/학습 프레임워크마다 옵션 이름이 다르지만, 개념적으로는 아래처럼 “토큰화 결과 및 encoder output”을 재사용합니다.

# 의사 코드: prompt -> tokens -> text_encoder_output 캐시
cache = {}

def get_text_emb(prompt, tokenizer, text_encoder, device):
    if prompt in cache:
        return cache[prompt]

    tokens = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        emb = text_encoder(tokens)[0]

    cache[prompt] = emb
    return emb

중요한 점은 학습 대상이 U-Net LoRA라면 텍스트 인코더는 대개 고정이므로 no_grad 로 안전하게 묶을 수 있다는 것입니다.

VAE latent 캐시

이미지 pixel 을 VAE로 latent로 바꾸는 과정도 매 step 반복되면 비용이 큽니다. 데이터셋이 정적이라면 사전 latent 캐시가 효과적입니다.

  • 장점: 학습 루프에서 VAE encode를 제거해 VRAM과 시간을 절약
  • 단점: 디스크 공간 사용, 증강(augmentation) 전략이 제한될 수 있음

간단한 사전 캐싱 예시는 아래와 같습니다.

import os
import torch
from torchvision import transforms
from PIL import Image

@torch.no_grad()
def cache_latents(image_paths, vae, out_dir, device="cuda"):
    os.makedirs(out_dir, exist_ok=True)
    vae.to(device)
    vae.eval()

    tfm = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    for p in image_paths:
        img = Image.open(p).convert("RGB")
        x = tfm(img).unsqueeze(0).to(device)
        latent = vae.encode(x).latent_dist.sample() * 0.18215

        name = os.path.splitext(os.path.basename(p))[0]
        torch.save(latent.cpu(), os.path.join(out_dir, f"{name}.pt"))

학습 시에는 이미지 대신 latent를 로드해 바로 U-Net으로 넣습니다.


4) Gradient checkpointing: 계산을 늘리고 VRAM을 줄이기

OOM 해결에서 가장 강력한 레버 중 하나가 gradient checkpointing 입니다. activation을 전부 저장하지 않고, backward 시 일부를 재계산하는 방식이라 VRAM을 크게 줄일 수 있습니다(대신 속도는 느려짐).

diffusers U-Net은 보통 아래처럼 켭니다.

pipe.unet.enable_gradient_checkpointing()

학습 스크립트 옵션으로는 다음 형태가 흔합니다.

accelerate launch train_lora.py \
  --gradient_checkpointing

권장 적용 순서는:

  1. xFormers 또는 SDPA
  2. mixed precision
  3. 그래도 부족하면 gradient checkpointing

5) 배치 크기 대신 gradient accumulation을 쓰기

배치 크기를 1까지 줄이면 학습이 불안정해지거나 수렴이 느려질 수 있습니다. 이때는 micro batch는 작게, 대신 gradient accumulation으로 유효 배치를 키우는 방식이 좋습니다.

accelerate launch train_lora.py \
  --train_batch_size 1 \
  --gradient_accumulation_steps 8

메모리 관점에서 핵심은 “한 번의 forward/backward에 올라가는 샘플 수”가 줄어든다는 점입니다.


6) 옵티마이저와 파라미터 구성: LoRA라도 상태 메모리가 생긴다

LoRA는 기본 모델 파라미터를 동결하므로 전체 Adam 상태를 만들지 않는 편이 일반적이지만, 아래 상황에서는 메모리가 늘 수 있습니다.

  • 실수로 U-Net 전체를 학습 대상으로 잡음
  • text encoder까지 함께 학습
  • rank가 과도하게 큼

점검 포인트:

  • 학습 대상 파라미터 수를 로그로 출력해 확인
  • LoRA rankalpha 를 필요 이상으로 키우지 않기

간단한 파라미터 수 점검 코드:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print("trainable", trainable, "/", all_params)

7) PyTorch CUDA 캐시/할당자 튜닝과 메모리 단편화 대응

가끔은 “분명 여유가 있어 보이는데도” OOM이 납니다. 이는 단편화(fragmentation)나 캐싱 할당자 설정 때문일 수 있습니다.

PYTORCH_CUDA_ALLOC_CONF 조정

환경변수로 split 크기를 조절해 단편화를 완화할 수 있습니다.

export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

상황에 따라 64 또는 256 도 시도합니다. 너무 작거나 크면 오히려 역효과가 날 수 있어, OOM이 특정 step에서만 간헐적으로 터지는 케이스에 특히 유용합니다.

메모리 사용량 로깅

학습 루프에 아래를 넣으면 “언제 급증하는지”를 빨리 찾습니다.

import torch

def log_mem(tag=""):
    a = torch.cuda.memory_allocated() / 1024**2
    r = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag} allocated={a:.1f}MB reserved={r:.1f}MB")

8) 해상도·버킷팅·크롭: OOM을 만드는 진짜 원인

LoRA 학습에서 해상도는 토큰 수(=latent spatial size)에 직접 영향을 주고, 어텐션 메모리를 폭발시킵니다.

  • 512 는 상대적으로 안전
  • 768 이상부터는 설정 조합이 조금만 틀어져도 OOM 가능성이 급상승

실전 팁:

  • 다양한 원본 해상도를 섞는다면 버킷팅(bucket) 으로 비슷한 크리티컬 해상도끼리 묶어 패딩 낭비를 줄이기
  • 랜덤 크롭을 쓰면 증강에는 좋지만, latent 캐시와 상충할 수 있으니 “캐시 전략”과 함께 설계하기

9) 추천 조합 프리셋: VRAM별 빠른 처방

8GB GPU

  • 해상도 512
  • train_batch_size=1
  • gradient_accumulation_steps=4 이상
  • xFormers 또는 SDPA 필수
  • FP16 필수, 가능하면 BF16
  • gradient checkpointing 켜기
  • VAE latent 캐시 적극 권장

12GB GPU

  • 해상도 512 또는 제한적 640
  • xFormers + mixed precision
  • checkpointing은 “부족할 때만”
  • 텍스트 인코더 캐시로 속도/메모리 안정화

24GB GPU

  • 768 도 가능하지만, 데이터셋/버킷/옵션에 따라 OOM 여지 존재
  • 고해상도에서 어텐션 최적화(xFormers/SDPA) 영향이 더 큼

트러블슈팅: 자주 나오는 실패 패턴

xFormers를 켰는데도 OOM

  • 실제로 적용이 안 된 경우가 많습니다. 학습 로그에 xFormers 활성화 메시지가 있는지 확인하세요.
  • PyTorch SDPA 경로로 타는지, math 커널로 강제되어 메모리 이득이 없는지 점검하세요.

FP16에서 NaN 발생

  • bf16 로 전환 가능한지 먼저 확인
  • 학습률을 낮추고, gradient clipping을 도입

특정 step에서만 간헐적으로 OOM

  • 단편화 가능성이 큼: PYTORCH_CUDA_ALLOC_CONF 튜닝
  • validation/generate 샘플링을 학습 중간에 돌린다면, 그 구간에서 no_grad 누락 여부 확인

마무리: OOM 해결은 “레버를 순서대로” 당기는 게임

LoRA 학습 OOM을 가장 효율적으로 줄이는 순서는 보통 아래가 정답에 가깝습니다.

  1. xFormers 또는 SDPA 적용으로 어텐션 메모리 절감
  2. FP16 또는 BF16 혼합정밀로 activation 절감
  3. 텍스트 인코더/ VAE 캐시로 반복 계산 제거
  4. gradient checkpointing으로 추가 절감
  5. 그래도 부족하면 해상도/버킷/배치와 accumulation 재설계

캐시와 메모리 병목을 함께 보는 관점은 LLM에서도 동일하게 통합니다. 더 확장된 캐시 튜닝 아이디어가 필요하면 Transformers 로컬 LLM OOM 해결 - KV캐시·PagedAttention도 같이 보면 도움이 됩니다.

다음 단계로는 “OOM은 해결했는데 느리다”가 문제로 바뀌는데, 그때는 데이터로더 병목, latent 캐시 I/O, 컴파일(torch.compile) 적용 여부까지 성능 프로파일링으로 넘어가면 됩니다.