- Published on
SDXL LoRA 학습 OOM·속도저하 해결 7가지
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
SDXL LoRA 학습을 돌리다 보면 가장 먼저 부딪히는 문제가 CUDA out of memory(OOM)와 “왜 이렇게 느리지?”라는 속도 저하입니다. 특히 SDXL은 베이스 모델 자체가 크고(UNet도 무겁고), 해상도와 배치, 옵티마이저 상태까지 겹치면 VRAM이 순식간에 바닥납니다.
이 글은 **OOM을 줄이면서도 학습 속도를 유지(또는 개선)**하기 위한 7가지 실전 체크리스트를 제공합니다. 예시는 주로 diffusers 기반 LoRA 학습을 기준으로 하지만, kohya-ss/sd-scripts 계열에서도 개념은 동일합니다.
1) “진짜 OOM”인지 “메모리 파편화”인지 먼저 구분하기
OOM은 크게 두 종류가 있습니다.
- 진짜 VRAM 부족: 필요한 텐서와 옵티마이저 상태를 합치면 물리적으로 안 들어감
- 메모리 파편화(fragmentation): 남은 VRAM은 충분해 보이는데 연속 블록 할당이 안 되어 실패
파편화는 특히 학습 도중 eval, save, validation 같은 단계에서 텐서가 생성·해제되며 심해질 수 있습니다.
빠른 진단 코드
import torch
def vram_report(tag=""):
torch.cuda.synchronize()
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f"[{tag}] allocated={allocated:.2f}GiB reserved={reserved:.2f}GiB")
vram_report("start")
allocated는 실제 사용 중인 메모리reserved는 캐싱(예약)된 메모리
reserved가 과도하게 큰데 OOM이 난다면 파편화 가능성이 큽니다.
파편화 완화 팁
- 학습 시작 전에 환경변수로 allocator 튜닝
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
- 검증/샘플링을 자주 돌린다면 주기를 늘리거나, 샘플링을 별도 프로세스로 분리
파편화는 “근본 해결”이 어렵기 때문에, 검증 빈도/배치/해상도 조합을 안정적으로 가져가는 것이 가장 효과적입니다.
2) 해상도·배치·그라디언트 누적을 “VRAM 예산”으로 재설계하기
SDXL LoRA에서 VRAM을 가장 크게 흔드는 3요소는 다음입니다.
resolution(또는crop/bucket)train_batch_sizegradient_accumulation_steps
여기서 핵심은 **유효 배치(effective batch)**를 유지하면서도 “한 번에 GPU에 올리는 배치”를 줄이는 것입니다.
- 유효 배치 =
train_batch_size * gradient_accumulation_steps
예를 들어 유효 배치를 8로 유지하고 싶다면:
train_batch_size=1,gradient_accumulation_steps=8train_batch_size=2,gradient_accumulation_steps=4
대부분의 OOM 상황에서는 전자가 더 안전합니다(대신 step당 오버헤드가 약간 증가).
diffusers 학습 실행 예시
accelerate launch train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--resolution 1024 \
--train_batch_size 1 \
--gradient_accumulation_steps 8 \
--mixed_precision fp16 \
--learning_rate 1e-4 \
--max_train_steps 4000
추가 팁:
- 1024가 버겁다면
resolution=768로 낮추고, 대신 데이터 버킷팅을 활용해 품질 손실을 줄입니다. - SDXL은 1024에서 급격히 무거워지므로, “OOM이 아슬아슬”하면 768로 내려 안정성을 확보하는 편이 결과적으로 빠릅니다(재시도 비용이 사라짐).
3) mixed precision을 “fp16 고정”이 아니라 상황별로 선택하기
많이들 --mixed_precision fp16을 기본으로 쓰지만, SDXL은 환경에 따라 bf16이 더 안정적일 때가 있습니다.
fp16: VRAM 절약 효과가 크고 빠른 편이지만, 오버/언더플로우로loss가 튀는 경우가 있음bf16: 수치 안정성이 더 좋고, Ampere 이후 GPU에서 성능이 잘 나오는 편
accelerate 설정 예시
accelerate config
# mixed_precision: bf16
혹은 실행 옵션으로:
accelerate launch --mixed_precision bf16 train_text_to_image_lora_sdxl.py ...
추가로, 아래 조합도 자주 씁니다.
- UNet은
fp16/bf16 - VAE는
fp32로 고정(특히 샘플링/검증 시)
VAE를 fp32로 두면 VRAM이 조금 늘 수 있지만, 검증 이미지가 깨지거나 색이 틀어지는 문제를 줄이는 데 도움이 됩니다.
4) xFormers·SDPA·gradient checkpointing을 “중복 적용”하지 말기
속도 저하의 대표 원인 중 하나가 메모리 절약 옵션을 과하게 겹쳐서 오히려 느려지는 경우입니다.
대표적인 옵션들은 다음과 같습니다.
xFormers메모리 효율 어텐션- PyTorch
SDPA(scaled dot product attention) gradient_checkpointing
원칙:
- VRAM이 부족하면
gradient_checkpointing이 효과적이지만 속도는 느려집니다. xFormers또는SDPA는 VRAM과 속도를 둘 다 개선하는 편이지만, 환경별로 편차가 있습니다.
diffusers에서 SDPA 활성화 예시
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
)
pipe.to("cuda")
# PyTorch 2.x
pipe.enable_attention_slicing() # 필요 시
pipe.unet.set_attn_processor("sdpa")
권장 접근:
- 먼저
xFormers또는SDPA중 하나만 적용하고 속도/VRAM 측정 - 그래도 OOM이면
gradient_checkpointing을 추가 attention_slicing은 최후의 수단(대체로 속도 손해가 큼)
5) 옵티마이저를 AdamW 기본값으로 고집하지 않기
OOM의 숨은 주범은 옵티마이저 상태(state) 입니다. Adam 계열은 파라미터마다 모멘텀/분산 등 추가 텐서를 들고 있어 VRAM을 많이 잡아먹습니다.
LoRA는 “학습 파라미터 수가 적다”는 장점이 있지만, 설정에 따라(LoRA rank를 과도하게 키우거나, text encoder까지 넓게 학습) 옵티마이저 상태가 빠르게 늘 수 있습니다.
대안:
8-bit Adam(bitsandbytes)Adafactor(메모리 절약형)
8-bit Adam 사용 예시
pip install bitsandbytes
accelerate launch train_text_to_image_lora_sdxl.py \
--optimizer adamw8bit \
--learning_rate 1e-4 \
--mixed_precision fp16
체감 효과:
- VRAM이 빡빡한 12GB, 16GB 환경에서 특히 유의미
- 속도는 GPU/드라이버 조합에 따라 비슷하거나 약간 느릴 수 있으나, OOM으로 중단되는 것보단 훨씬 낫습니다.
6) 데이터 로더 병목을 잡아 “GPU가 놀지 않게” 만들기
OOM만큼 흔한 게 속도 저하입니다. 그런데 많은 경우 모델이 느린 게 아니라, 데이터 로딩/전처리가 병목이라 GPU가 대기합니다.
증상:
- GPU 사용률이 30~60%에서 왔다 갔다 함
- step time이 들쭉날쭉함
해결 포인트:
num_workers증가persistent_workers=Truepin_memory=True- 이미지 리사이즈/버킷팅을 사전 캐싱
PyTorch DataLoader 예시
from torch.utils.data import DataLoader
loader = DataLoader(
dataset,
batch_size=1,
shuffle=True,
num_workers=8,
pin_memory=True,
persistent_workers=True,
prefetch_factor=4,
)
추가 팁:
- 학습 이미지가 네트워크 스토리지(NFS, S3 FUSE 등)에 있으면 지연이 커집니다. 가능하면 로컬 SSD로 내려받아 학습하세요.
- 이미지가 너무 크고 매 step마다 리사이즈하면 CPU가 병목이 됩니다. “학습용 해상도 버전”을 미리 만들어 두는 것이 가장 확실합니다.
비슷한 맥락으로, 장애가 나면 원인을 체계적으로 좁혀가는 운영 습관이 중요합니다. 쿠버네티스 환경에서 학습 워커가 재시작을 반복한다면 Kubernetes CrashLoopBackOff 원인 12가지와 진단처럼 로그/리소스/프로브 관점으로 원인을 분리해보는 방식이 그대로 적용됩니다.
7) 검증(샘플링)과 체크포인트 저장이 학습을 “주기적으로 멈추게” 한다
SDXL LoRA는 학습 자체도 무겁지만, 주기적으로 수행하는 작업들이 속도를 크게 갉아먹습니다.
- validation 이미지 생성(추론 파이프라인 로드, VAE 디코딩)
- 체크포인트 저장(디스크 I/O)
- 로그 이미지 저장(특히 PNG 다량 저장)
해결 전략:
- validation 주기를 늘리기
- 저장 포맷/빈도 최적화
- 가능하면 검증은 별도 프로세스에서(학습 중간 산출물만 읽기)
저장 주기 예시(개념)
# 예: 200 step마다 저장, 500 step마다 validation
--checkpointing_steps 200 \
--validation_steps 500
디스크 I/O가 느린 환경(공유 볼륨, 네트워크 디스크)에서는 체크포인트 저장이 “짧은 프리징”을 만들고, 이게 누적되면 체감 속도가 크게 떨어집니다. 특히 클러스터에서 ALB나 프록시 뒤로 상태를 노출하는 워크플로라면 타임아웃까지 이어질 수 있는데, 이런 류의 “느려져서 발생하는 장애”는 EKS ALB Ingress 502 target timeout 원인·해결 같은 글의 접근처럼 병목 지점을 먼저 분리하는 게 효과적입니다.
(보너스) OOM·속도저하를 동시에 줄이는 추천 조합
환경별로 다르지만, 많이 안정적으로 쓰는 조합을 정리하면 다음과 같습니다.
train_batch_size=1gradient_accumulation_steps로 유효 배치 맞추기mixed_precision=bf16(가능하면) 또는fp16- 어텐션 최적화는
xFormers또는SDPA중 하나 - VRAM이 부족하면
gradient_checkpointing추가 - 옵티마이저는
adamw8bit고려 - DataLoader
num_workers/persistent_workers/pin_memory튜닝 - validation/저장 주기 늘리기
마무리: “한 번에 하나씩” 바꾸고, 수치로 비교하기
SDXL LoRA 학습 튜닝은 정답이 고정돼 있지 않습니다. 같은 GPU라도 드라이버, PyTorch 버전, xFormers 빌드 여부, 데이터셋 형태에 따라 결과가 달라집니다.
가장 좋은 방법은 다음 루틴입니다.
- 기준 실행을 만들고(step time, VRAM 사용량, GPU utilization 기록)
- 위 7가지 중 하나만 바꿔서 재측정
- OOM이 사라지고 step time이 줄어드는 조합을 확정
만약 학습 파이프라인을 서비스 형태로 돌리며 외부 API 호출(예: 프롬프트 생성, 태깅, 메타데이터 보강)을 끼워 넣는다면, 느려질 때 재시도/타임아웃을 안전하게 설계하는 것도 중요합니다. 이때는 Python 데코레이터로 async 타임아웃·재시도 패턴의 패턴을 적용해 병목이 학습 전체를 멈추지 않게 만드는 방식이 도움이 됩니다.
원하시면 사용 중인 GPU 모델, VRAM 용량, 학습 스크립트(diffusers 또는 kohya-ss), 목표 해상도/배치, 현재 발생 로그(OOM stack trace 일부)를 기준으로 “당장 통과되는 설정 조합”을 구체적으로 잡아드릴게요.