Published on

SDXL LoRA 학습 OOM·느림 해결 - xFormers·8bit·캐시

Authors

SDXL LoRA 학습을 돌리다 보면 가장 먼저 부딪히는 문제가 CUDA out of memory(OOM)과 “왜 이렇게 느리지?”입니다. SDXL은 베이스 모델 자체가 크고(UNet도 크고, 텍스트 인코더도 2개), 기본 해상도도 높아서 SD 1.5 시절의 감각으로 세팅하면 VRAM이 순식간에 터지거나, 겨우 돌아가도 스텝당 시간이 너무 길어집니다.

이 글은 “학습 품질을 크게 해치지 않으면서” VRAM을 줄이고, 병목을 줄여서 스텝 시간을 단축하는 방법을 xFormers, 8bit 옵티마이저, 캐시(특히 VAE/텍스트 인코더) 중심으로 정리합니다. NaN 폭주 같은 수치 안정성 이슈는 별도 글에서 다룬 적이 있으니, 필요하면 함께 참고하세요: Stable Diffusion LoRA 학습 NaN 폭주 잡는법

SDXL LoRA가 OOM/느린 이유를 먼저 분해하기

SDXL LoRA 학습의 리소스 사용은 대략 아래 4개로 나뉩니다.

  1. UNet forward/backward: 가장 큰 VRAM 소비처입니다. 특히 attention의 QKV와 중간 activation이 큽니다.
  2. Text Encoder 2개(CLIP-L, CLIP-G): 학습에 포함시키면 VRAM과 시간이 크게 증가합니다.
  3. VAE 인코딩: 원본 이미지를 latent로 바꾸는 과정이 CPU/GPU 시간을 먹습니다. 매 스텝마다 하면 느립니다.
  4. 옵티마이저 상태(Adam 계열): 파라미터보다 옵티마이저 상태가 더 큰 경우가 많습니다. 8bit가 여기서 효과가 큽니다.

즉, “OOM”은 보통 UNet activation + optimizer state + (text encoder)가 합쳐져 터지고, “느림”은 VAE 인코딩 반복 + 비효율 attention + 데이터 파이프라인이 겹쳐서 발생합니다.

1) xFormers: 메모리 효율 어텐션으로 OOM 확률을 크게 낮추기

xFormers가 해결하는 것

xFormers의 핵심은 memory-efficient attention입니다. 일반 attention은 중간 행렬을 크게 잡아먹는데, xFormers는 이를 더 효율적으로 계산해 VRAM을 줄이고 종종 속도도 개선합니다(환경에 따라 편차 있음).

특히 SDXL처럼 해상도가 높아 attention map이 커지는 상황에서 체감이 큽니다.

설치/적용 포인트

학습 프레임워크마다 옵션 이름이 다르지만, 보통 아래 중 하나로 켭니다.

  • --xformers
  • --enable_xformers_memory_efficient_attention
  • diffusers 기반이면 pipe.enable_xformers_memory_efficient_attention()

예시(파이썬, diffusers/accelerate 계열):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)

# xFormers 메모리 효율 어텐션 활성화
pipe.enable_xformers_memory_efficient_attention()

pipe.to("cuda")

주의사항(호환성/성능)

  • PyTorch, CUDA, xFormers 버전 조합이 안 맞으면 설치가 지옥일 수 있습니다. 이때는 pip install xformers가 아니라, 환경에 맞는 휠을 쓰거나 PyTorch 버전을 맞추는 게 정답입니다.
  • 어떤 GPU/드라이버 조합에서는 xFormers가 오히려 느릴 수 있습니다. 이 경우 아래의 “캐시”와 “8bit”로 VRAM을 확보하고 xFormers는 끄는 선택지도 있습니다.

2) 8bit 옵티마이저: VRAM을 가장 ‘안전하게’ 줄이는 방법

왜 8bit가 LoRA에 특히 잘 맞나

LoRA는 학습 파라미터 수가 상대적으로 적지만, Adam 같은 옵티마이저는 보통 파라미터마다 m, v 상태를 들고 있어 VRAM을 꽤 씁니다. 8bit 옵티마이저(bitsandbytes)는 이 상태 메모리를 크게 줄여줍니다.

  • 효과: VRAM 절감
  • 부작용: 환경에 따라 약간의 속도 변화(대개 큰 문제 없음), 설치 이슈 가능

bitsandbytes 8bit Adam 예시

import bitsandbytes as bnb
import torch

# LoRA 파라미터만 optimizer에 넣는다고 가정
optimizer = bnb.optim.AdamW8bit(
    lora_parameters,
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

scaler = torch.cuda.amp.GradScaler()

학습 스크립트 옵션으로는 보통 아래처럼 켭니다.

  • --optimizer_type AdamW8bit
  • --use_8bit_adam

8bit 적용 체크리스트

  • LoRA만 학습하는데도 OOM이면, 가장 먼저 8bit Adam을 켜는 것이 성공 확률이 높습니다.
  • Windows 환경은 bitsandbytes 호환이 까다로운 경우가 있습니다. WSL2로 옮기면 해결되는 케이스가 많습니다.

3) 캐시: “느림”의 주범인 VAE/텍스트 인코더 반복을 없애기

SDXL LoRA 학습이 느린 가장 흔한 이유는 매 스텝마다 같은 일을 반복하기 때문입니다.

  • 같은 이미지에 대해 VAE 인코딩을 매번 수행
  • 같은 캡션(프롬프트)에 대해 텍스트 인코딩을 매번 수행

이 둘은 “학습 파라미터 업데이트”와 무관하게 입력을 만드는 과정이라, 캐시로 크게 줄일 수 있습니다.

3-1) VAE latent 캐시

원본 이미지를 latent로 바꾼 결과를 디스크나 메모리에 저장해두면, 학습 루프에서 VAE를 거의 호출하지 않아도 됩니다.

장점:

  • 스텝당 시간 감소
  • GPU 사용량이 더 안정적

단점:

  • 캐시 저장 공간 필요
  • 데이터 증강(랜덤 크롭, 컬러 등)을 학습 중에 강하게 쓰는 경우, 캐시 전략을 재설계해야 함

개념 예시(사전 캐시):

from pathlib import Path
import torch

@torch.no_grad()
def cache_latents(dataset, vae, out_dir: str):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    for i, image in enumerate(dataset.images):
        # image: (C,H,W) float tensor
        latent = vae.encode(image.unsqueeze(0)).latent_dist.sample()
        latent = latent * 0.18215
        torch.save(latent.cpu(), out / f"{i:06d}.pt")

학습 시에는 image 대신 latent를 로드합니다.

3-2) 텍스트 인코더 출력 캐시(SDXL은 2개라 더 중요)

SDXL은 텍스트 인코더가 2개라, 캡션 인코딩 비용이 더 큽니다. 캡션이 고정(혹은 소수)이라면 캐시는 체감이 큽니다.

  • 캐시 대상: prompt_embeds, pooled_prompt_embeds
  • 캐시 키: 캡션 문자열 + 토크나이저 설정 + max_length 등

개념 예시:

import hashlib
import torch

def cache_key(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

@torch.no_grad()
def encode_and_cache(texts, tokenizer_1, tokenizer_2, te1, te2, out_dir: str):
    for t in texts:
        key = cache_key(t)
        # tokenizer/encoder 호출 후 결과 저장
        prompt_embeds, pooled = run_sdxl_text_encoders(
            t, tokenizer_1, tokenizer_2, te1, te2
        )
        torch.save({
            "prompt_embeds": prompt_embeds.cpu(),
            "pooled": pooled.cpu(),
        }, f"{out_dir}/{key}.pt")

3-3) “캐시를 켰는데도 느리다”면: 데이터 로더 병목

캐시를 쓰면 GPU가 더 빨리 다음 배치를 요구하므로, 이제는 CPU/디스크 I/O가 병목이 될 수 있습니다.

  • num_workers 증가
  • pin_memory=True
  • 캐시 파일을 너무 잘게 쪼개지 않기(파일 수가 많으면 메타데이터 I/O가 병목)
  • 가능하면 NVMe 사용

이런 병목 진단은 DB/서버 글이지만 관점은 동일합니다. “리소스가 어디에서 막히는지”를 먼저 확인하는 방식이 중요합니다: Spring Boot 3 가상스레드에서 DB풀 병목 진단법

4) 추가로 잘 먹히는 OOM 완화 옵션들(우선순위 포함)

아래는 xFormers/8bit/캐시 외에, SDXL LoRA에서 자주 쓰는 옵션들입니다.

4-1) mixed precision: fp16 또는 bf16

  • VRAM 절감과 속도 개선에 기본적으로 유리
  • bf16은 지원 GPU에서 안정성이 더 좋은 경우가 있음

CLI 예시:

accelerate launch train_lora_sdxl.py \
  --mixed_precision fp16

4-2) gradient checkpointing

activation을 저장하지 않고 필요 시 재계산해서 VRAM을 줄입니다.

  • 장점: OOM 방지
  • 단점: 속도는 느려질 수 있음

옵션 예시:

--gradient_checkpointing

4-3) 텍스트 인코더 학습을 끄기

SDXL에서 텍스트 인코더까지 학습하면 VRAM/시간이 확 늘어납니다. 스타일 LoRA나 캐릭터 LoRA 대부분은 UNet만으로도 충분한 경우가 많습니다.

옵션 예시:

--train_text_encoder false

4-4) 해상도/버킷/배치/누적 스텝 재조정

OOM은 결국 batch_size * resolution * activation의 함수입니다.

  • batch_size를 1로 낮추고, 대신 gradient_accumulation_steps로 유효 배치를 맞추기
  • 해상도는 SDXL에서 1024가 기본이지만, 데이터가 단순하면 768로 시작해도 학습이 성립하는 경우가 있음

예시:

--train_batch_size 1 \
--gradient_accumulation_steps 4 \
--resolution 1024

5) 추천 조합(실전 프리셋)

8GB VRAM(매우 빡빡)

  • fp16
  • 8bit Adam 필수
  • xFormers 켜기 시도
  • gradient checkpointing 켜기
  • batch_size=1, 누적으로 맞추기
  • VAE latent 캐시 적극 권장

12GB VRAM(가장 흔한 구간)

  • fp16 또는 bf16
  • 8bit Adam 강력 추천
  • xFormers 켜기
  • 텍스트 인코더는 보통 끄고 시작
  • VAE/텍스트 캐시로 속도 최적화

24GB VRAM 이상(여유)

  • xFormers는 켜서 VRAM 마진 확보
  • 캐시로 학습 throughput 올리기
  • 텍스트 인코더 학습은 목적이 명확할 때만(특정 토큰/문구 정밀 튜닝)

6) “무엇부터 켜야 하나?” 우선순위 체크리스트

  1. OOM이면: mixed precision 확인(fp16 또는 bf16)
  2. 다음: 8bit Adam 적용(옵티마이저 상태 VRAM 절감)
  3. 다음: xFormers 적용(attention VRAM 절감)
  4. 그래도 OOM이면: gradient checkpointing
  5. 느리면: VAE latent 캐시, 텍스트 인코더 캐시
  6. 여전히 느리면: DataLoader I/O 튜닝, 캐시 파일 구조 개선

추가로, 메모리 문제를 다룰 때는 “계속 쌓이는 메모리(누수/캐시 미스)”인지 “피크에서 터지는 메모리(순간 OOM)”인지 구분하는 게 중요합니다. 이 관점은 LLM 에이전트 쪽이지만 참고할 만합니다: AutoGPT 메모리 누수 막는 벡터DB TTL·압축

7) 마무리: SDXL LoRA의 정답은 ‘VRAM 절감’과 ‘반복 제거’

SDXL LoRA 학습 최적화는 결국 두 줄로 요약됩니다.

  • VRAM 피크를 낮춰서 OOM을 피한다: xFormers, 8bit Adam, mixed precision, checkpointing
  • 스텝마다 반복되는 전처리를 제거해 throughput을 올린다: VAE latent 캐시, 텍스트 인코더 캐시, DataLoader 튜닝

위 조합으로도 안 되면, 그때는 “학습 목표(스타일/캐릭터/의상/구도) 대비 과한 설정”을 의심해야 합니다. 예를 들어 텍스트 인코더까지 학습하면서 캡션 품질이 낮으면, 비용만 늘고 결과는 나빠질 수 있습니다. 먼저 UNet LoRA + 캐시 + 8bit + xFormers로 안정적인 파이프라인을 만든 다음, 필요할 때만 옵션을 하나씩 추가하는 방식이 가장 빠르게 수렴합니다.