- Published on
Stable Diffusion LoRA 학습 붕괴 디버깅 - fp16·bf16·EMA
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서로 다른 LoRA 설정을 돌리다 보면 어느 순간부터 결과물이 전부 뭉개지거나, 학습 중 손실이 nan 으로 바뀌거나, 특정 스텝 이후 갑자기 과포화된 색감과 노이즈만 남는 “학습 붕괴”를 겪습니다. 문제는 대개 데이터가 아니라 정밀도(fp16·bf16), 옵티마이저 상태, EMA 적용, 스케일링(gradient/loss) 같은 수치 안정성에서 시작됩니다.
이 글은 Stable Diffusion LoRA 학습에서 자주 발생하는 붕괴 패턴을 재현 가능한 관찰 지표로 분해하고, fp16·bf16·EMA를 중심으로 어디를 어떻게 의심해야 하는지 디버깅 절차를 제공합니다.
참고로 OOM이나 메모리 압박이 원인인 경우도 많습니다. 그 경우는 Transformers 로컬 LLM CUDA OOM, 4bit·KV캐시 최적화 글의 “증상 기반 진단” 방식이 그대로 도움이 됩니다.
LoRA 학습 붕괴의 전형적인 증상 분류
먼저 “붕괴”를 하나로 뭉뚱그리면 해결이 늦어집니다. 아래처럼 나눠 보면 원인 후보가 빠르게 줄어듭니다.
1) 손실이 갑자기 nan 또는 inf
- 특정 스텝 이후 손실이
nan으로 고정 - 그 직전부터 gradient norm이 비정상적으로 커짐
- 체크포인트를 로드해도 같은 지점에서 재발
가장 흔한 원인
- fp16에서 overflow, 또는 잘못된 loss scaling
- Adam 계열 옵티마이저 상태가 fp16으로 유지됨
- EMA가 fp16으로 누적되며 오차가 폭발
2) 손실은 내려가는데 샘플이 점점 뭉개짐
- 학습 중간까지는 괜찮다가, 후반부에 인물 얼굴이 녹거나 질감이 단순화
- 특정 토큰이나 스타일만 과도하게 강화
가장 흔한 원인
- 과한 학습률 또는 rank 대비 과한 alpha
- prior preservation 미적용 또는 캡션 편향
- EMA를 잘못 적용해서 “평균 모델”이 오히려 나쁜 방향으로 수렴
3) 특정 프롬프트에서만 갑자기 망가짐
- 일반 프롬프트는 괜찮은데 트리거 토큰 포함 시만 붕괴
가장 흔한 원인
- 텍스트 인코더까지 학습하면서 fp16 오차가 누적
- 데이터셋 캡션이 트리거 토큰에 과도하게 결합
디버깅의 핵심: “수치 안정성”을 가시화하기
학습이 무너질 때 로그에 손실만 찍고 있으면 원인을 못 잡습니다. 최소한 아래 지표를 함께 기록하세요.
lossgrad_norm또는 gradient clipping 전후 normlr- AMP 스케일러 값(사용 시)
- 파라미터/그라디언트에
nan존재 여부
PyTorch에서 NaN 감지 훅
아래는 학습 루프에 얹어두기 좋은 최소 훅입니다.
import torch
def has_nan(t):
return torch.isnan(t).any().item()
def has_inf(t):
return torch.isinf(t).any().item()
@torch.no_grad()
def check_model_numerics(model):
for name, p in model.named_parameters():
if p is None:
continue
if has_nan(p) or has_inf(p):
return False, f"param {name} has nan/inf"
return True, "ok"
@torch.no_grad()
def check_grads_numerics(model):
for name, p in model.named_parameters():
if p.grad is None:
continue
if has_nan(p.grad) or has_inf(p.grad):
return False, f"grad {name} has nan/inf"
return True, "ok"
학습 중 특정 스텝에서 grad 쪽이 먼저 터지는지, param 이 먼저 터지는지에 따라 원인이 달라집니다.
grad먼저nan이면: loss scaling, fp16 overflow, 입력 스케일 문제 가능성param먼저nan이면: 옵티마이저 업데이트(특히 Adam state), EMA 업데이트 로직 문제 가능성
fp16 vs bf16: “빠름”보다 “안정성”이 우선인 구간
fp16이 무너지는 전형적인 조건
fp16은 표현 가능한 지수 범위가 좁아서, 아래 조건에서 overflow가 자주 발생합니다.
- 학습률이 높음
- gradient accumulation을 크게 잡아 한 번의 업데이트가 큼
- UNet만이 아니라 텍스트 인코더까지 학습
- LoRA rank가 높고 alpha가 커서 업데이트가 공격적
- 노이즈 스케줄 또는 loss 가중이 특정 구간에서 급격
이때 손실이 nan 으로 바뀌는 지점은 “우연”이 아니라, 특정 배치에서 값이 폭발한 결과일 확률이 큽니다.
bf16이 상대적으로 안전한 이유
bf16은 fp16보다 mantissa 정밀도는 낮지만, 지수 범위가 fp32에 가깝습니다. 즉 overflow에 강합니다. LoRA 학습 붕괴를 잡는 가장 빠른 방법 중 하나가 fp16에서 bf16으로 바꾸는 것입니다.
다만 모든 GPU가 bf16을 잘 지원하는 것은 아닙니다. Ampere 이후(예: A100, RTX 30 일부, RTX 40)에서 일반적으로 유리합니다.
권장 우선순위
- 가능하면 bf16
- bf16이 안 되면 fp16 + 안정화 장치(clip, 낮은 lr, scaler 조정)
AMP와 loss scaling: “자동”이 만능은 아니다
PyTorch AMP를 쓰면 보통 GradScaler 가 loss scaling을 자동으로 조절합니다. 그런데 LoRA 학습 파이프라인에 따라 다음 문제가 생깁니다.
- scaler가 과도하게 커져서 overflow 유발
- scaler가 너무 자주 줄어들어 실질 학습이 정체
- gradient accumulation과 결합 시 스케일링 타이밍이 꼬임
AMP 사용 시 체크 포인트
- scaler 값이 시간이 지날수록 단조 증가하는지
- overflow 감지로 인해 scaler가 계속 내려가는지
학습 로그에 scaler를 남기세요.
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
# training step
with torch.cuda.amp.autocast(enabled=use_fp16):
loss = compute_loss(...)
scaler.scale(loss).backward()
# optional: gradient clipping은 unscale 이후
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
scaler.step(optimizer)
scaler.update()
current_scale = scaler.get_scale()
여기서 중요한 포인트는 clip은 unscale_ 이후에 해야 한다는 점입니다. 순서가 바뀌면 clip이 사실상 무의미해지거나, 반대로 과도하게 작동할 수 있습니다.
옵티마이저 상태(dtype)와 8bit Adam의 함정
LoRA는 학습 파라미터 수가 적어서 “옵티마이저 상태가 문제일까”라는 의심을 잘 안 합니다. 하지만 fp16 붕괴의 많은 케이스가 Adam의 모멘트 추정치(state)가 낮은 정밀도로 유지되며 발생합니다.
점검 포인트
- 옵티마이저 state가 fp32인지
- 8bit Adam을 사용할 때 특정 GPU/드라이버 조합에서 수치 문제가 없는지
일반적으로 안정성을 최우선으로 둘 때는 다음이 무난합니다.
- bf16 + AdamW(fp32 state)
- fp16이면 특히 Adam state가 fp32인지 확인
프레임워크에 따라 다르지만, “모델은 fp16인데 옵티마이저는 fp32” 조합이 흔히 가장 안정적입니다.
EMA: 켜면 좋아질 수도, 망가질 수도 있다
EMA(Exponential Moving Average)는 학습 파라미터의 이동 평균을 유지해 샘플 품질을 안정화하는 데 도움을 줍니다. 하지만 LoRA에서는 다음 케이스에서 오히려 붕괴를 유발하거나, “학습은 정상인데 결과만 망가진” 것처럼 보이게 합니다.
1) EMA를 fp16으로 누적
EMA는 누적 평균이라 작은 오차가 계속 쌓입니다. EMA 파라미터를 fp16으로 유지하면, 미세한 업데이트가 양자화되어 사라지거나 반대로 특정 방향으로 치우칠 수 있습니다.
권장
- EMA 파라미터는 fp32로 유지
- 모델 forward는 bf16 또는 fp16을 쓰더라도 EMA 저장은 fp32
2) EMA 적용 대상이 잘못됨
LoRA 학습에서 EMA를 “전체 UNet”에 걸어버리면, 실제로 업데이트되는 건 LoRA 레이어뿐인데도 EMA가 전체 파라미터를 복사하고 평균내며 불필요한 오버헤드와 버그 포인트가 생깁니다.
권장
- EMA 대상은 “학습되는 파라미터만”
- 즉 LoRA 파라미터(및 학습 시킨 텍스트 인코더 일부)로 제한
3) 평가 시점에 EMA와 원본 가중치가 섞임
샘플링 시점에 다음 실수가 흔합니다.
- 일부 모듈은 EMA, 일부는 non-EMA
- 저장은 EMA인데 로드는 non-EMA
- UNet은 EMA인데 text encoder는 non-EMA
이 경우 “학습은 정상, 샘플만 붕괴”처럼 보입니다.
EMA 구현 예시(학습 파라미터만, fp32 유지)
import copy
import torch
class EMA:
def __init__(self, params, decay=0.999):
self.decay = decay
self.shadow = {}
for name, p in params:
if not p.requires_grad:
continue
self.shadow[name] = p.detach().float().clone()
@torch.no_grad()
def update(self, params):
for name, p in params:
if name not in self.shadow:
continue
new = p.detach().float()
old = self.shadow[name]
self.shadow[name] = old * self.decay + new * (1.0 - self.decay)
@torch.no_grad()
def apply_to(self, params):
for name, p in params:
if name in self.shadow:
p.copy_(self.shadow[name].to(dtype=p.dtype, device=p.device))
# usage
trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
ema = EMA(trainable, decay=0.999)
# after optimizer step
ema.update(trainable)
핵심은 shadow 를 fp32로 들고 가는 것입니다.
“붕괴”를 막는 설정 체크리스트
여기서는 원인별로 바로 적용 가능한 처방을 정리합니다.
A) 손실이 nan 으로 튄다
- bf16 사용 가능하면 bf16로 전환
- 학습률을 2배에서 10배까지 낮춰 재현 여부 확인
- gradient clipping
max_norm=1.0또는0.5적용 - AMP 사용 시
unscale_이후 clip 순서 확인 - 텍스트 인코더 학습을 끄고 UNet LoRA만으로 재현 테스트
- EMA를 잠시 끄고 동일 설정으로 비교
B) 샘플이 후반에 뭉개진다(과적합 또는 과한 업데이트)
- LoRA rank를 낮추거나 alpha를 낮춤
- 학습 스텝을 줄이고, 중간 체크포인트에서 샘플 품질 비교
- 데이터 캡션을 정리하고, 트리거 토큰의 사용 빈도를 낮춤
- EMA를 켰다면 “EMA 결과”와 “non-EMA 결과”를 동시에 저장해 비교
C) 특정 프롬프트에서만 망가진다
- 트리거 토큰을 캡션에서 분리해 과결합을 줄임
- 텍스트 인코더 학습 시 bf16 권장
- validation 프롬프트를 여러 그룹으로 고정해 회귀 테스트
이런 “원인별 체크리스트” 접근은 장애 대응에서도 유사합니다. 예를 들어 GCP Cloud Run 503/504 원인별 해결 - 타임아웃·동시성 글처럼, 증상을 분류하고 관측 지표를 늘리면 해결 속도가 급격히 빨라집니다.
diffusers 기준: 안전한 기본 조합 예시
아래는 “안정성 우선” 설정의 방향성을 보여주는 예시입니다. 사용하는 스크립트에 맞게 옵션 이름은 조정하세요.
accelerate launch train_text_to_image_lora.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
--train_data_dir="./data" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--lr_scheduler="cosine" \
--lr_warmup_steps=200 \
--max_train_steps=3000 \
--checkpointing_steps=500 \
--validation_prompt="a portrait photo" \
--seed=42 \
--mixed_precision="bf16" \
--enable_xformers_memory_efficient_attention \
--max_grad_norm=1.0
fp16을 써야 한다면 다음을 추가로 고려합니다.
--learning_rate를 더 낮게--max_grad_norm을 더 보수적으로- EMA를 쓰면 shadow를 fp32로 유지
“한 번에 다 바꾸지 말고” 최소 실험 단위로 쪼개기
붕괴 디버깅에서 가장 큰 실수는 설정을 여러 개 동시에 바꾸는 것입니다. 아래 순서로 최소 실험 단위를 유지하면 원인을 빠르게 좁힐 수 있습니다.
- 동일 데이터, 동일 seed, 동일 validation 프롬프트 고정
- 정밀도만 변경: fp16
→bf16 - EMA만 on/off
- lr만 2배 단위로 조정
- 텍스트 인코더 학습 on/off
- 옵티마이저(AdamW
↔8bit Adam) 변경
여기서 화살표 기호가 포함된 텍스트는 MDX에서 안전하게 → 같은 엔티티로 표기하는 습관을 들이면 빌드 에러도 함께 예방됩니다.
결론: 붕괴는 “데이터 탓”이기 전에 “정밀도 탓”일 때가 많다
Stable Diffusion LoRA 학습 붕괴는 상당수 케이스에서 fp16의 overflow, AMP 스케일링, 옵티마이저 상태 정밀도, EMA 누적 방식 같은 “수치” 문제로 설명됩니다. 특히 다음 3가지만 지켜도 체감적으로 붕괴 빈도가 크게 줄어듭니다.
- 가능하면 bf16 사용
- gradient clipping과 AMP 순서(
unscale_이후 clip) 준수 - EMA를 쓴다면 fp32 shadow로, 적용 대상을 학습 파라미터로 제한
마지막으로, 문제를 한 번에 해결하려고 설정을 마구 섞기보다, 관측 지표를 늘리고 원인 후보를 하나씩 제거하는 방식이 가장 빠릅니다. 이는 데이터베이스 성능 문제에서 원인 좁히는 것과도 유사합니다. 예를 들어 MongoDB 느린 집계 파이프라인 $lookup 최적화 8단계 처럼 “측정 → 가설 → 최소 변경” 루프를 돌리면 LoRA 학습도 안정적으로 수렴 경로를 찾을 수 있습니다.