- Published on
Stable Diffusion LoRA 학습 OOM 해결 - xFormers·SDXL
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
Stable Diffusion에서 LoRA를 학습하다 보면 가장 흔히 마주치는 에러가 CUDA out of memory 입니다. 특히 SDXL은 베이스 모델 자체가 크고(UNet/텍스트 인코더/해상도 구성까지) 학습 파이프라인이 복잡해져서, SD 1.5에서 잘 되던 설정이 그대로는 절대 버티지 못하는 경우가 많습니다.
이 글은 “배치 줄여라” 같은 단편 처방이 아니라, 어떤 요소가 VRAM을 먹는지를 분해해서 보고, xFormers 또는 PyTorch SDPA를 포함한 현실적인 조합으로 OOM을 해결하는 방법을 정리합니다. SDXL LoRA 학습을 기준으로 설명하지만, SD 1.5에도 동일한 원리가 적용됩니다.
관련해서 SDXL 파이프라인을 확장(예: ControlNet, IP-Adapter 동시 적용)할 때도 VRAM이 급증하는데, 그때의 감각을 잡는 데는 아래 글도 도움이 됩니다.
OOM의 정체: VRAM은 어디에 쓰이나
학습 시 GPU 메모리는 대략 아래 항목으로 나뉩니다.
- 모델 파라미터(Weights): UNet, 텍스트 인코더, VAE 등
- 옵티마이저 상태(Optimizer states): Adam 계열이면 보통 파라미터의 2배 정도 추가(모멘텀/분산)
- 그래디언트(Gradients): 학습 대상 파라미터에 대해 추가
- 활성값(Activations): forward 중간 결과. 배치/해상도/UNet 블록 깊이에 비례
- 어텐션 중간 버퍼(Attention buffers): 특히 self-attn은 토큰 수가 커지면 급증
- CUDA 캐시/메모리 단편화:
reserved가 커지고allocated는 낮은데도 OOM이 나는 패턴
LoRA는 “전체 파라미터를 학습하지 않는다”는 점에서 2, 3번을 크게 줄여주지만, 4, 5번(활성값/어텐션 버퍼) 는 여전히 큽니다. 즉, LoRA라고 해서 무조건 VRAM이 넉넉해지는 게 아니라, 해상도와 어텐션 구현이 사실상 승패를 좌우합니다.
먼저 확인할 것: 지금 OOM이 단편화인지, 진짜 부족인지
OOM 로그에 allocated 와 reserved 가 함께 나오는 경우가 많습니다.
allocated는 실제 사용 중인 메모리reserved는 PyTorch가 CUDA allocator로 잡아둔 풀
reserved 가 지나치게 크고 allocated 가 상대적으로 작다면, 단편화 또는 캐시 정책 때문에 “남아 보이는데도” OOM이 납니다. 이때는 아래가 효과적입니다.
- 학습 시작 전에 커널/프로세스 재시작(가장 확실)
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True설정- validation 이미지 생성 등 불필요한 inference를 학습 프로세스와 분리
메모리 스냅샷 빠르게 찍기
아래 코드는 학습 루프 중간에 VRAM 사용량을 확인할 때 유용합니다.
import torch
def vram(msg=""):
if not torch.cuda.is_available():
return
a = torch.cuda.memory_allocated() / 1024**3
r = torch.cuda.memory_reserved() / 1024**3
m = torch.cuda.max_memory_allocated() / 1024**3
print(f"[VRAM]{msg} allocated={a:.2f}G reserved={r:.2f}G max={m:.2f}G")
# 예: forward 직후
vram(" after forward")
핵심 1: xFormers로 어텐션 메모리 줄이기
OOM의 주범은 종종 어텐션입니다. 특히 SDXL은 기본 해상도/토큰 구성에서 어텐션 비용이 빠르게 커집니다. xFormers는 메모리 효율적인 attention 커널을 제공해 어텐션 중간 버퍼를 크게 줄이는 효과가 있습니다.
(1) xFormers 설치
환경에 따라 다르지만, 일반적으로는 아래처럼 설치합니다.
pip install -U xformers
CUDA/PyTorch 버전에 따라 휠 호환이 깨질 수 있습니다. 그 경우에는 다음 순서로 점검하세요.
torch.__version__과 CUDA 버전 매칭- xFormers가 해당 torch 버전을 지원하는지 확인
- 안 맞으면 torch를 맞추거나, xFormers 대신 PyTorch SDPA로 전환
(2) diffusers에서 xFormers 켜기
diffusers 기반 학습이라면 보통 아래 한 줄로 적용됩니다.
pipe.unet.enable_xformers_memory_efficient_attention()
훈련 스크립트에 따라 unet 객체를 직접 다루는 경우도 많습니다.
unet.enable_xformers_memory_efficient_attention()
xFormers가 오히려 느리거나 불안정한 경우
간혹 드라이버/커널 조합에서 xFormers가 속도 이점이 없거나, 특정 GPU에서 불안정한 경우가 있습니다. 이때는 PyTorch 2.x의 SDPA(scaled dot-product attention)로 우회할 수 있습니다.
핵심 2: PyTorch SDPA로 대체(또는 병행)하기
PyTorch 2.x는 torch.nn.functional.scaled_dot_product_attention 기반 최적화를 제공합니다. 프레임워크가 자동으로 최적 커널을 선택하게 하거나, 설정으로 유도할 수 있습니다.
SDPA 커널 선택 힌트
아래는 환경에 따라 효과가 갈립니다. 하지만 xFormers가 깨지는 환경에서 “최소한의 메모리 절감”을 얻는 데 도움이 됩니다.
import torch
# 일부 환경에서 메모리 효율 커널을 더 선호하도록 유도
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
주의할 점은, 모델/드라이버 조합에 따라 Flash Attention 계열이 비활성화되거나 정확도/재현성 이슈가 생길 수 있다는 것입니다. 학습이 “돌아가기만 하면 된다”가 목표라면 mem-efficient 쪽을 우선 고려하고, 재현성이 중요하면 설정을 보수적으로 가져가세요.
핵심 3: Gradient Checkpointing은 LoRA에서도 여전히 강력
LoRA는 옵티마이저 상태를 줄여주지만, 활성값 메모리는 크게 줄지 않습니다. Gradient checkpointing은 forward 중간 활성값을 저장하지 않고, backward 때 재계산해서 VRAM을 줄입니다(대신 속도 손해).
diffusers/transformers 계열에서 보통 아래처럼 켭니다.
unet.enable_gradient_checkpointing()
# 텍스트 인코더까지 학습하는 경우
text_encoder.gradient_checkpointing_enable()
SDXL에서 특히 효과가 큰 편이며, 12GB~16GB VRAM에서 “학습이 되느냐 마느냐”를 가르는 옵션이 되기도 합니다.
핵심 4: Mixed Precision과 LoRA dtype 전략
OOM을 줄이는 가장 쉬운 방법 중 하나는 fp16 또는 bf16입니다.
- fp16: 대부분 GPU에서 빠르고 메모리 절감 큼
- bf16: 수치 안정성이 더 좋은 편(지원 GPU 필요)
다만 LoRA 가중치까지 무조건 fp16으로 두면 학습이 불안정해지는 경우가 있어, 다음 전략이 실전에서 자주 쓰입니다.
- UNet forward는 fp16 또는 bf16
- LoRA 파라미터는 fp32 유지(옵션)
프레임워크마다 설정 방식이 다르지만, accelerate 를 쓴다면 보통 런처 옵션으로 제어합니다.
accelerate launch \
--mixed_precision=fp16 \
train_lora.py
bf16이 가능한 환경이면:
accelerate launch \
--mixed_precision=bf16 \
train_lora.py
핵심 5: SDXL에서 가장 위험한 파라미터는 해상도와 배치
SDXL은 1024 계열 해상도가 기본처럼 느껴지지만, 학습에서는 VRAM이 급격히 증가합니다. 특히 다음 조합이 OOM을 잘 유발합니다.
- 높은 해상도(예: 1024)
- 큰 배치
- 많은 토큰 길이
- xFormers/SDPA 미적용
- gradient checkpointing 미적용
실전 권장 조합(예: 12GB~16GB)
- 해상도: 768 또는 512부터 시작
- batch size: 1
- gradient accumulation: 4~16으로 보정
- gradient checkpointing: on
- xFormers 또는 SDPA: on
- optimizer: 8-bit Adam 사용 고려
여기서 중요한 포인트는 배치를 줄이는 대신 gradient accumulation으로 유효 배치를 맞추는 것입니다.
# 예시: 배치 1로 돌리되, 누적 8로 유효 배치 8 느낌
--train_batch_size=1 \
--gradient_accumulation_steps=8
핵심 6: 8-bit 옵티마이저로 옵티마이저 상태 메모리 줄이기
LoRA는 학습 파라미터 수가 적어도, 옵티마이저 상태가 무시 못 할 때가 있습니다(특히 텍스트 인코더까지 학습하거나 rank를 높게 잡았을 때).
bitsandbytes 기반 8-bit Adam은 옵티마이저 상태 메모리를 줄여 OOM을 피하는 데 유용합니다.
pip install -U bitsandbytes
학습 스크립트 옵션 예:
--optimizer=adamw8bit
주의: 일부 환경(특히 Windows, 특정 CUDA 조합)에서는 설치/동작 이슈가 있을 수 있습니다. 그 경우 xFormers/체크포인팅/해상도 조합으로 먼저 안정화한 뒤, 8-bit 옵티마이저는 마지막에 얹는 것이 디버깅이 쉽습니다.
핵심 7: SDXL은 텍스트 인코더 학습이 OOM을 앞당긴다
SDXL 계열 LoRA 학습에서 텍스트 인코더까지 학습하면 품질이 좋아질 때가 있지만, VRAM 사용량이 확 늘 수 있습니다. OOM이 잦다면 우선 아래 우선순위를 추천합니다.
- UNet LoRA만 학습
- 안정화되면 텍스트 인코더를 부분적으로 학습(또는 rank 낮게)
- 그래도 필요하면 두 인코더 모두 학습(설정 난이도 상승)
또한 SDXL은 텍스트 인코더가 2개 구성인 경우가 많아(파이프라인/스크립트에 따라) 메모리 압박이 더 커집니다.
OOM을 줄이는 체크리스트(우선순위 순)
아래는 “효과 대비 부작용이 적은 순서”로 정리한 체크리스트입니다.
- 학습 프로세스 재시작 후 재시도(단편화 제거)
- xFormers 활성화 또는 SDPA 최적화
- gradient checkpointing 활성화
- mixed precision(fp16 또는 bf16)
- 해상도 낮추기(1024
->768->512, 반드시->를 백틱으로 감쌈) - batch size 1 + gradient accumulation으로 유효 배치 확보
- 8-bit optimizer 적용
- 텍스트 인코더 학습 비활성화(또는 범위 축소)
이 흐름은 애플리케이션 OOM을 분석할 때 “힙 덤프 뜨고, 큰 객체부터 줄인다”는 접근과 비슷합니다. 메모리 문제를 체계적으로 쪼개는 감각은 아래 글의 방식도 참고가 됩니다.
예시: diffusers + accelerate 기반 SDXL LoRA 학습 설정 샘플
아래는 “OOM을 피하는 쪽”으로 안전장치를 많이 건 커맨드 예시입니다. 스크립트 옵션은 프로젝트마다 다르니, 개념을 가져가서 본인 학습 코드에 맞게 매핑하세요.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
accelerate launch \
--mixed_precision=fp16 \
train_sdxl_lora.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--resolution=768 \
--train_batch_size=1 \
--gradient_accumulation_steps=8 \
--learning_rate=1e-4 \
--max_train_steps=2000 \
--enable_xformers \
--gradient_checkpointing \
--optimizer="adamw8bit"
스크립트 내부에서 --enable_xformers 같은 플래그를 받지 않는다면, 파이썬 코드에서 직접 켜면 됩니다.
# train_sdxl_lora.py 내부 예시
# 1) xFormers
try:
unet.enable_xformers_memory_efficient_attention()
except Exception as e:
print("xFormers enable failed, fallback to SDPA:", e)
# 2) gradient checkpointing
unet.enable_gradient_checkpointing()
자주 나오는 OOM 패턴과 처방
패턴 A: 학습 시작 직후 바로 OOM
- 원인: 초기 forward에서 어텐션/활성값이 한 번에 폭발
- 처방: xFormers 또는 SDPA
+gradient checkpointing+해상도 768 이하부터
패턴 B: 몇 step 돌다가 갑자기 OOM
- 원인: validation 생성, 로깅 이미지 생성, 캐시 누적, 단편화
- 처방: 학습 중간 inference를 별도 프로세스로 분리,
PYTORCH_CUDA_ALLOC_CONF적용, 주기적torch.cuda.empty_cache()는 “최후의 수단”으로만(근본 해결은 아님)
패턴 C: reserved 는 큰데 allocated 가 낮아도 OOM
- 원인: 단편화/세그먼트 정책
- 처방: 프로세스 재시작,
expandable_segments적용, 불필요한 텐서 참조 제거
마무리: SDXL LoRA OOM은 ‘옵션 조합’ 싸움이다
SDXL LoRA 학습에서 OOM을 해결하는 가장 현실적인 정답은 하나가 아니라 조합입니다.
- 어텐션 최적화(xFormers 또는 SDPA) 로 큰 덩어리를 줄이고
- gradient checkpointing 으로 활성값을 줄이며
- mixed precision 으로 전체 메모리 풋프린트를 낮춘 뒤
- 해상도/배치/누적 스텝을 조절해 품질과 속도를 맞추는 흐름이 안정적입니다.
이 과정을 거치면 12GB~16GB에서도 SDXL LoRA가 “학습 가능한 영역”으로 들어오는 경우가 많습니다. 다음 단계로는 학습 품질(데이터/캡션/정규화)과 속도(컴파일, 캐시, 데이터로더)가 병목이 되는데, 그때는 로컬 LLM 최적화 글에서 다룬 것처럼 병목을 계측하고 캐시 전략을 세우는 접근이 의외로 그대로 통합니다.