- Published on
Stable Diffusion LoRA 학습 NaN 폭주 잡는법
- Authors
- Name
- 스타차일드
- https://x.com/ETFBITX
서로 잘 맞던 설정인데도 LoRA 학습이 어느 순간부터 loss=nan으로 터지거나, 특정 스텝 이후 손실이 무한대로 치솟는 문제는 생각보다 흔합니다. 특히 Stable Diffusion 계열(UNet + Text Encoder)에서 mixed precision, 고해상도 버킷, 과한 학습률, 잘못된 데이터 전처리(깨진 이미지/비정상 캡션) 같은 요소가 겹치면 수치가 쉽게 불안정해집니다.
이 글은 “왜 NaN이 생기는지”를 원인별로 분해하고, 실제로 많이 쓰는 kohya-ss/sd-scripts 및 diffusers 기반 학습에서 바로 적용 가능한 안정화 레시피를 제공합니다.
NaN 폭주를 이해하는 최소 수치 상식
딥러닝에서 NaN은 대개 다음 경로로 생깁니다.
- 오버플로우: FP16/BF16에서 표현 범위를 넘어가
inf가 생기고, 이후 연산에서inf - inf,0 * inf같은 조합으로 NaN이 전파됩니다. - 언더플로우/정밀도 손실: 아주 작은 값이 0으로 깎이면서 분모가 0에 가까워지는 연산(정규화, 분산 계산 등)에서 불안정해집니다.
- 그래디언트 폭주: 학습률이 크거나 배치가 작아 노이즈가 큰 경우, 특정 레이어(특히 attention/conv)에서 gradient norm이 급증합니다.
- 입력 이상치: 깨진 이미지, 알파 채널/색공간 문제, 극단적으로 긴 캡션/토큰, NaN이 포함된 텐서가 들어오면 첫 스텝부터 NaN이 날 수 있습니다.
핵심은 “NaN이 언제 발생하는지”를 좁혀서, 데이터 문제인지(초반부터), 학습률/정밀도 문제인지(몇백~몇천 스텝 후) 구분하는 것입니다.
1) 가장 먼저 할 일: 재현 가능한 관측 지점 만들기
NaN을 잡는 과정은 디버깅입니다. 관측이 없으면 시행착오만 늘어납니다.
체크포인트 1: NaN 발생 스텝과 직전 로그 확보
- 학습 로그에
step,loss,lr를 반드시 남기세요. - 가능하면 gradient norm도 함께 기록하세요.
diffusers/PyTorch 공통으로 쓸 수 있는 간단한 NaN 가드 예시는 아래처럼 만들 수 있습니다.
import torch
def has_nan(t: torch.Tensor) -> bool:
return torch.isnan(t).any().item()
def has_inf(t: torch.Tensor) -> bool:
return torch.isinf(t).any().item()
@torch.no_grad()
def check_model_params(model):
for n, p in model.named_parameters():
if p is None:
continue
if has_nan(p) or has_inf(p):
raise RuntimeError(f"param exploded: {n}")
학습 루프에서 optimizer.step() 직후나 loss.backward() 직후에 호출해 “어느 시점부터 파라미터가 망가지는지”를 고정하세요.
체크포인트 2: 데이터 1배치 고정 테스트
데이터가 원인인지 가장 빨리 확인하는 방법은 단일 배치 overfit 입니다.
- 데이터 1~2장만으로 학습
- 학습률 낮게
loss가 계속 감소하는지 확인
여기서도 NaN이 나면 데이터/전처리/정밀도 설정 문제가 유력합니다.
2) 데이터 원인: 깨진 이미지, 알파 채널, 비정상 캡션
LoRA 학습의 NaN은 의외로 “모델”이 아니라 “데이터”에서 시작하는 경우가 많습니다.
흔한 데이터 문제 패턴
- 손상된 파일(부분 다운로드, 0바이트, 헤더 깨짐)
- PNG 알파 채널 처리 미흡(
RGBA를 그대로 텐서화) - CMYK/JPEG 색공간 변환 문제
- 해상도 버킷팅 과정에서 0 또는 음수 크기(커스텀 스크립트에서 발생)
- 캡션에 제어문자/너무 긴 토큰(특히 자동 태깅 결과가 폭주)
이미지 무결성 빠르게 검사하기
학습 전에 데이터셋을 한 번 스캔하면, “몇 시간 학습하다 터지는” 상황을 줄일 수 있습니다.
from PIL import Image
import os
def scan_images(root):
bad = []
for dirpath, _, filenames in os.walk(root):
for fn in filenames:
if not fn.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
continue
path = os.path.join(dirpath, fn)
try:
with Image.open(path) as im:
im.verify() # 무결성 검사
with Image.open(path) as im:
im = im.convert("RGB")
w, h = im.size
if w <= 0 or h <= 0:
bad.append((path, "invalid size"))
except Exception as e:
bad.append((path, str(e)))
return bad
bad = scan_images("./train")
print("bad count:", len(bad))
for p, r in bad[:20]:
print(p, r)
verify()는 디코딩 전체를 보장하진 않지만, 깨진 파일을 상당수 걸러냅니다.- 학습 파이프라인에서 최종적으로는
RGB로 통일하는 게 안전합니다.
캡션 길이 제한과 정규화
- 캡션이 너무 길면 attention 쪽에서 수치가 불안정해질 수 있습니다.
- 토큰 길이를 제한하거나, 불필요한 반복 태그를 제거하세요.
kohya 계열에서는 캡션 드롭아웃/셔플 옵션도 NaN 자체를 “직접” 막지는 않지만, 특정 토큰 조합에 과적합되며 폭주하는 케이스를 줄여줍니다.
3) mixed precision(특히 FP16)에서 폭주: BF16 또는 FP32로 격리
NaN의 1순위는 FP16 오버플로우입니다.
우선순위 높은 처방
- GPU가 지원하면 BF16로 변경
- 안 되면 UNet만 FP16, 나머지 FP32 등 부분 격리
- 그래도 안 되면 전체 FP32로 원인 분리
accelerate 사용 시 개념적으로는 다음과 같습니다.
from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="bf16") # 또는 "fp16"/"no"
- BF16은 FP16보다 표현 범위가 넓어 오버플로우에 강합니다.
- 학습이 느려지더라도 “NaN 재현을 끊는” 용도로 FP32는 매우 유효합니다.
loss scaling 설정 점검
FP16에서 GradScaler가 자동 loss scaling을 하는데, 학습이 불안정하면 스케일이 계속 흔들리며 NaN으로 이어질 수 있습니다.
- 자동 스케일링이 불안정하면 BF16로 옮기는 것이 보통 더 빠른 해결입니다.
4) 학습률과 스케줄러: LoRA는 생각보다 작은 LR이 안전
LoRA는 전체 모델 파인튜닝보다 적은 파라미터를 업데이트하지만, 그만큼 학습률이 과하면 특정 레이어가 빠르게 발산할 수 있습니다.
안전한 시작점(경험칙)
- UNet LoRA:
1e-4에서 시작, 불안정하면5e-5또는1e-5 - Text Encoder LoRA: UNet보다 더 낮게(
5e-5또는1e-5)
특히 아래 조합은 NaN을 자주 만듭니다.
- 작은 데이터셋 + 큰 LR
- batch가 너무 작음 + gradient accumulation 없이 진행
- warmup 없이 바로 큰 LR
warmup은 사실상 필수에 가깝다
초반 수십~수백 스텝에서 그래디언트가 요동치기 쉬우므로 warmup이 도움이 됩니다.
diffusers에서 스케줄러를 쓰는 예시:
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=200,
num_training_steps=max_train_steps,
)
5) Gradient clipping으로 폭주 상한선 만들기
NaN은 “한 번의 폭주”로 시작해 전체 파라미터를 오염시키는 경우가 많습니다. gradient clipping은 그 한 번을 막는 안전장치입니다.
import torch
max_norm = 1.0
torch.nn.utils.clip_grad_norm_(trainable_params, max_norm)
- 보통
1.0또는0.5가 무난합니다. - clipping을 켜도 계속 NaN이면, 근본 원인이 데이터/정밀도/LR일 가능성이 큽니다.
6) Optimizer 선택: AdamW 8-bit, Adafactor, Lion의 함정
옵티마이저 자체가 NaN을 “만들기”보다는, 특정 조합에서 불안정성을 증폭시키는 경우가 있습니다.
AdamW는 가장 예측 가능8-bit AdamW(bitsandbytes)는 VRAM 절약에 좋지만, 환경/버전에 따라 수치 이슈가 보고됩니다Adafactor는 설정에 따라 안정적일 수도 있지만, LoRA에서는 튜닝 포인트가 늘어납니다
문제 해결 우선순위는 다음이 좋습니다.
AdamW로 단순화- NaN이 없어지면 그 다음에 8-bit로 최적화
7) 해상도/버킷/배치: “큰 해상도 + 작은 배치” 조합 주의
고해상도 버킷을 쓰면 attention map이 커지고, 메모리 압박이 커지면서 gradient checkpointing, mixed precision 등 여러 요소가 동시에 켜집니다. 이때 NaN이 더 잘 발생합니다.
권장 디버깅 절차:
- 해상도를
512고정으로 낮춤 - 버킷팅 끔
- batch를 가능한 범위에서 키우거나 accumulation으로 유효 배치 증가
- 안정화 후 원래 해상도로 복귀
이 과정은 “원인 분리”에 매우 효과적입니다.
8) Text Encoder까지 학습할 때 NaN이 늘어나는 이유
Text Encoder는 작은 LR에도 민감합니다. 특히 CLIP 계열은 토큰/포지션 임베딩과 layer norm을 포함해 수치적으로 예민한 편입니다.
- UNet만 학습했을 때는 괜찮다가, Text Encoder까지 켜는 순간 터지면
- Text Encoder LR을 UNet의
1/5~1/10로 낮추고 - warmup을 늘리며
- 가능하면 BF16을 사용하세요.
- Text Encoder LR을 UNet의
또는 “Text Encoder는 학습하지 않고, 캡션 품질을 올리는 방식”이 결과적으로 더 좋은 경우도 많습니다.
9) NaN이 났을 때 즉시 할 수 있는 응급 처치
이미 학습이 진행 중인데 NaN이 발생했다면, 아래 순서로 대응합니다.
- 즉시 중단: NaN 이후 스텝은 대부분 가중치를 오염시킵니다.
- 마지막 정상 체크포인트로 롤백
- 다음 중 하나를 적용하고 재시작
- LR을
1/2또는1/10로 감소 - mixed precision을
bf16또는no로 변경 - gradient clipping 활성화
- 문제 배치가 데이터라면 해당 파일 제거
- LR을
학습 파이프라인을 운영 관점에서 보면 “폭주를 감지하고 자동 중단”이 중요합니다. 이건 장애 대응과 비슷한데, 원인 분석 체크리스트를 갖춰두면 재발 방지가 쉬워집니다. 운영 장애를 체크리스트로 줄이는 방식은 systemd 서비스 무한 재시작 - Exit code 203 해결 같은 글의 접근과도 결이 같습니다.
10) kohya-ss 기준: 안정화에 자주 쓰는 설정 조합
환경마다 다르지만, “NaN을 우선 막는” 쪽으로 보수적으로 잡은 예시는 아래와 같습니다. 부등호가 들어갈 수 있는 표현은 인라인 코드로 처리합니다.
accelerate launch train_network.py \
--pretrained_model_name_or_path="./sd15.safetensors" \
--train_data_dir="./train" \
--output_dir="./out" \
--output_name="lora_stable" \
--network_module=networks.lora \
--network_dim=16 \
--network_alpha=16 \
--resolution=512 \
--enable_bucket=false \
--train_batch_size=2 \
--gradient_accumulation_steps=4 \
--max_train_steps=4000 \
--learning_rate=1e-4 \
--unet_lr=1e-4 \
--text_encoder_lr=1e-5 \
--lr_scheduler=cosine \
--lr_warmup_steps=200 \
--mixed_precision=bf16 \
--clip_skip=2 \
--max_grad_norm=1.0 \
--save_every_n_steps=500
포인트:
- 버킷/고해상도/공격적 LR 같은 변수를 일단 줄여 “안정 상태”를 만든 뒤 확장
--max_grad_norm은 폭주 방지에 체감이 큼bf16가능하면 우선 적용
11) diffusers 학습 루프에서 NaN을 조기에 잡는 패턴
diffusers로 커스텀 학습을 짜는 경우, 아래처럼 “손실이 NaN이면 즉시 중단”을 걸어두면 시간 낭비를 크게 줄일 수 있습니다.
import math
import torch
loss = ...
if not torch.isfinite(loss):
raise RuntimeError(f"non-finite loss detected: {loss.item()}")
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(trainable_params, 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
torch.isfinite로nan과inf를 동시에 차단- clipping은
sync_gradients시점에만 적용(gradient accumulation 사용 시)
12) 원인별 빠른 체크리스트(실전용)
A. 첫 스텝부터 NaN
- 데이터 깨짐/디코딩 실패 여부 스캔
- 이미지
RGB변환 강제 - 캡션 파일 인코딩/제어문자 확인
- mixed precision을
no로 바꿔 재현되는지 확인
B. 수백~수천 스텝 후 NaN
- LR 과대 가능성:
lr을1/2~1/10 - warmup 추가/증가
- gradient clipping 적용
- 해상도/버킷 끄고 재현되는지 확인
C. 특정 배치에서만 NaN(간헐적)
- 데이터 로더에서 샘플 인덱스/파일명 로깅
- 문제 파일 격리 후 제거
- augmentation 파이프라인(랜덤 크롭/리사이즈)에서 0 크기 발생 여부 확인
데이터에서 “특정 조합이 행 폭증을 만들듯” 학습에서도 “특정 샘플이 폭주를 만든다”는 형태가 자주 나옵니다. 이런 유형의 진단 접근은 Pandas merge에서 행 폭증? 중복키 진단법처럼 원인을 좁히는 방식이 그대로 통합니다.
13) 결론: NaN을 ‘없애는’ 가장 현실적인 순서
- 데이터 무결성 검사 + RGB 통일
- mixed precision은 가능하면 BF16
- LR을 보수적으로(특히 Text Encoder는 더 낮게) + warmup
- gradient clipping으로 단발 폭주 차단
- 해상도/버킷/옵티마이저 최적화는 안정화 이후
NaN은 한 가지 원인만으로 생기기보다, “작은 불안정성”이 여러 개 겹쳐서 임계점을 넘을 때 터지는 경우가 많습니다. 따라서 설정을 한 번에 많이 바꾸기보다, 한 번에 한 변수씩 바꾸며 재현성을 유지하는 게 가장 빠른 해결책입니다.