Stable diffusion은 매우 좋은 성능을 띄지만, 문제점이 존재한다. 매우 어둡거나 매우 밝은 이미지를 생성하도록 시도해도 항상 평균값이 0.5에 가까운 이미지를 생성한다. (완전 검정색 이미지는 0, 완전 흰색 이미지는 1일때)
평균이 0.5에 가까운 이미지를 생성하는 제약 조건은 white나 black이 아닌 grey를 생성하는 문제, 빈 영역이 아닌 high-frequency texture (로고 등)를 생성하는 문제, 사물이 washed out되고 밝은 fog 영역이 어두운 영역을 상쇄하는 문제 등을 야기함. 이 문제들 중 일부는 후처리를 통해서 해결할 수 있겠지만, 근본적인 해결책이 되진 않음.
왜 이러한 문제가 발생할까? 훈련 데이터의 문제일까? 모델 아키텍쳐의 문제일까? 아니면 일반적인 Diffusion model의 한계일까?
SD를 실제로 Dreambooth와 같은 기법으로 fine-tuning하면 꽤 잘 작동한다. 특정 사람의 얼굴이나 특정 고양이를 잘 생성한다. 이는 수개 또는 수십개의 이미지와 수천번의 gradient update로 충분히 그 물체의 생김새를 배운다. 10000 step이 지나면 특정 이미지를 기억하기 시작한다.
하지만 하나의 단색의 검은색 이미지로 파인튜닝 하면 아래와 같은 결과가 나옴.
SD 모델은 매우 어둡거나, 매우 밝은 이미지를 생성할 수 도 없고, 학습할 수도 없는 것 처럼 보인다. 문제를 어떻게 해결할 수 있을까?? 원인은 뭘까?
Diffusion model이 reverse 과정을 어떻게 학습하는 지를 살펴보자. Diffusion model의 reverse는 “적은양의 iid(independently, identically distributed) Gaussian noise를 점진적으로 더하는” forward의 역과정이다.
즉 latent space에서 각 픽셀은 각 step에서 random noise를 받는다. diffusion model은 이러한 forward 단계를 여러번 거친후, 원래 이미지로 궤적을 따라 돌아가는 방향을 찾는것을 학습한다. Inference 할때에는 원래 이미지로 돌아갈 수 있는 모델이 주어지면, pure 한 noise로부터 noising process를 reverse하며 새로운 이미지를 얻게 된다.
문제는 forward process는 원본이미지를 완전히 지우지 않기 때문에, pure한 noise로부터 완전한 실제 분포로 정확히 돌아오지는 못한다. latent noise가 마지막으로 파괴하는 정보들은 reverse 과정에 의해서 가장 약하게 변경된다. 이 정보들은 reverse process 시작할때 남아있다.
같은 seed를 사용하고 다른 프롬프트를 사용하면 전체적인 수준에서는 서로 비슷한 이미지를 생성하지만, local한 small-scale의 pattern 수준에서는 그렇지 않다. Diffusion model은 확실히 long wavelength feature를 변경하는 것을 어려워한다. longest wavelength feature는 이미지 전체의 평균이고, 이는 latent noise의 독립적인 샘플 간에 변할 가능성이 가장 적다.
이 문제는 대상 객체의 차원이 높을수록 더욱 심해진다. 왜냐하면 독립적인 noise sample set의 표준편차가 1/N으로 scaling 되기 때문이다.
Stable diffusion 모델은 3X512X512의 이미지를 VAE에 넣어 4X64X64의 latent vector로 변환하여 처리하는데, 이 때의 dimension은 4X64X64=16384이다.
longest wavelength 관점에서 noise를 주는 것을 생각하면, 분산이 N만큼 scaling 된다. -> 다시 말해서 ,mean이라고 할 수 있는 longest wavelength는 가장 shortest한 wavelength라고 할 수 있는 pixel에 비해서 N배나 더 분산이 작다고 볼 수 있다. (channel을 고려하지 않았을때 SD에서 64배 더 표준편차가 적음)
즉 이론상 의 표준오차(통계량의 표준편차)는 longest wavelength에서 1/N 만큼 scaling된다. 때문에 longest-wavelength는 shortest-wavelength 에 비해 N 만큼 덜 변한다.
noise = torch.randn_like(latents) + 0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1)