One-step Diffusion with Distribution Matching Distillation

Evergyu·2024년 4월 20일

논문리뷰

목록 보기
3/6
post-thumbnail

https://tianweiy.github.io/dmd/

ABSTRACT

우리는 2개의 스코어 펑션으로 표현되는 KL divergence를 최소화하는 one-step image generator를 소개한다. score function들은 두 개의 디퓨전 모델을 학습하여 각각 학습된다. ImageNet 데이터에 대해 2.56의 FID를 달성하였고 COCO-30k 데이터에 대해 11.49FID를 달성하였다.

1. INTRODUCTION

Diffusion models 들은 안정적이지만 GAN과 VAE와는 다르게 많은 시간을 필요로 한다.

이전에는 샘플링 가속화를 위해 Single student 방식으로 진행하였고 최근에는 여러 과정을 통합하여 Fully하게 process하지 않으려는 연구들이 있었다.
그런데 그렇게 과정을 통합하여 스텝을 줄였을 때 성능이 original 보다 떨어지는 부작용이 있었다.

우리는 생성 이미지와 원본 이미지의 관련성을 강조하는 것이 아닌 origianl model과 student generation이 구분할 수 없는 점을 강조하였다. 고차원적으로 생각하면 GAN과 GMMN에서 text-to-image task를 수행하는 것에서 아이디어를 얻었다.
이 work에서 우리는 대용량 text-to-image data로 사전학습된 모델을 합성 이미지가 더욱 score function으로 가짜 이미지에 대해 더 가짜로 해석하고 실제 이미지에 대해 더 realism하게 해석하게 하기 위해 실제 이미지와 distilled generator로 생성한 가짜 이미지를 fine tuning하였다. 마지막으로 생성기에 대하여 기울기 업데이트 규칙을 둘(실제 이미지를 학습한 score function, 가짜 이미지를 학습한 score function) 사이의 nudge로 생성하도록 한다.

더 나아가 우리는 적당한 수의 샘플링 결과를 사전 계산하고 간단한 회귀 손실을 적용하는 것이 효과적인 정규화가 되는 것을 발견하였다. 이 방법은 1. 실제와 가짜를 구분하는 diffusion model을 모델링하고 2. simple regression loss를 사용하여 multi-step diffusion output을 유도하면 one-step generative model을 사용할 수 있다.

Diffusion model

Diffusion model은 여러 domain에서 좋은 성과를 보였다. 이 모델은 noise를 reverse process를 통해 coherent structure로 변환한다. 하지만 computational cost가 많이 필요하다는 단점 때문에 실제 세계에서는 적용하기 힘들다. 근데 우린 가능

Diffusion Acceleration

추론 과정을 가속화 하는 것이 key 였다.
1. advances fast diffusion samplers: 사전학습 diffusion model을 통해 샘플링 과정을 줄이는 방법이 있다. 하지만 성능이 줄어드는 단점이 있었다.
2. diffusion distillationing speed

knowledge distillation: student model이 multi-step output을 학습하여 싱글 스텝으로 결과를 내는 방법. loss function 구현에 큰 비용이 드는 단점이 있다.

Progressive Distillation: student model이 학습해야 할 step을 나눠서 학습한다.

Consistency distilation: student 모델이 own output을 다른 timestep으로 결과를 내도록 학습한다.

우리는 simple approach를 보여준다.

Distribution Matching

학습 목적에 따라 noise injection이나 token masking같은 방법이 아닌 타겟 샘플과 생성 샘플의 distribution을 비교하는 시도도 있다. (GAN, GMMD)
GAN은 놀라운 성능을 보여주었지만 text-guided에는 안정적인 성능을 보이지 못해 덜 사용되었다.

최근, score-based model과 distribution matching을 하는 방식이 시도되었다.

Variational Score Distillation: 사전학습된 TTI diffusion model을 distribution matching의 loss로 사용하였다. VSD는 unpaired setting에서 대용량 사전학습 모델을 사용할 수 있었기 때문에 text-conditioned 3D synthesis에서 놀라운 성과를 보였다

우리는 VSD를 확장하여 모델링을 하였고 GAN의 image translation에서의 성공을 동기로 안정적인 regression loss를 학습하도록 하는 모델링에 성공하였다.

이는 GAN과 diffusion을 합친 것과는 결과가 다르다. 우리는 이 방법을 통해 text-to-image task에서 SOTA의 성능을 얻었다.

3. Distribution matching distillation

우리의 목표는 one-step으로 high-quality image를 생성하는 것이다. 우리는 distilled model의 생성 결과를 fake, 반대를 real로 구분하여 학습한다.
두 개의 loss를 사용함

  1. distribution matching objective: 두 score function의 차이를 통해 gradient를 update하기 위한 loss(3.2 절)

  2. regression loss: generator가 noise-image pair를 통해 생성하는 구조를 가진 loss (3.3 절)

마지막으로, classifire-free guidance 설명(3.4절)

Untitled

3.1 Pretrained based model and One-step generator

μbase(xt,t)\mu_{base}(x_t,t): T=1000으로 학습한 Gausian diffusion process로 x0x_0에서 prealp_{real}을 생성하는 모델

EDM and Stable Diffusion 의 pretrained model 사용

One-step generator(GθG_{\theta})

time conditioning이 없는 diffusion denoiser 기반

Gθ=μbase(z,T1)G_{\theta}=\mu_{base}(z,T-1)

3.2 Distribution Matching Loss

생성이미지와 가짜 이미지 사이의 KL divergence를 loss로 사용하고 최소화 하려함

이 loss를 추정하기 위해 확률 밀도를 계산하는 것은 어려운 일이기에, 기울기 하강에 의해 생성기를 훈련시키기 위해 θ\theta에 대한 기울기가 필요함

Gradient update using approximate scores

p= produced result / s=score - Eq2

Untitled

Untitled

sreals_{real}prealp_{real}을 향해 움직이고 (a), sfakes_{fake}sreals_{real}을 퍼지게 한다 (b). - 1차목표

적합하게 떨어지게 하는 것(+regression) - 2차 목표

Untitled

이 loss를 계산하는 것은 두 가지 문제가 있다.

  1. 확률이 낮은 가짜 샘플의 경우 점수가 차이난다
  2. 점수를 추정하기 위한 모델은 확산 분포의 점수만 제공한다

random gaussian noise로 분포를 “blurred”시켜서 Gradient 식에 잘 맞도록 할 수 있다.

(a)가 noise 추가 전 (b)가 noise 추가 후

Untitled

그래서 우리의 전략은 real 과 fake distribution에 대한 score denoiser model을 사용하여

generator에 noise를 추가함

Real Score

real distribution은 고정되어 있기 때문에 사전학습 모델의 copy 버전을 사용

Untitled

Dynamically-learned fake score

학습하며 generated sample이 변하기 때문에

Untitled

ϕ\phi 파라미터를 통해 조절한다. 아래는 standard denoising objective function이다.

Untitled

Distribution matching gradient update

마지막 추정 distribution matching gradient은 Eq.2 를 통해 얻을 수 있고 diffusion timestamp 에 따라 다시 작성하면

Untitled

wtw_t는 학습 dynamics를 향상시키기 위한 time-dependent scalar weight이다.

우리는 진폭의 기울기를 정규화 하도록 weighting factor를 설계하였다.

그리고 Input image와 denoised image 사이의 mean absolute error를 계산하였다.

3.3 Regression loss and final objective

distribution matching은 t가 큰 경우에는 잘 정의되어 있지만 t가 작은(noise가 적은) 경우에는 prealp_{real}이 0으로 수렴하기 때문에 sreals_{real}에 의존할 수 없다. 이를 피하기 위해 regression loss 사용함

R_Loss는 동일한 noise가 주어지면 GG와 base diffusion model 출력 pointwise distance 측정함

이를 측정하기 위해 random gaussian noise image zz 와 corresponding output yy pair의 데이터셋을 구성함. CIFAR-10 = 18 step, PNDM: 50 step … 1%의 noise만 있어도 정규화로의 작용을 잘 함

Untitled

ll = Learned Perceptual Image Patch Similarity

Final objective

최종 목적함수

Untitled

3.4 Distillation with classifier-free guidance

Calssifier-free guidance는 TTI에서 iamge의 퀄리티를 향상시키는데 많이 사용됨. 그래서 사용한다.
1. regression loss 계산을 위해 corresponding noise-output pair를 생성
2. KL divergence 계산을 위해 guided model의 mean prediction real score 생성, 반면에 fake score를 위한 공식은 수정하지 않음.

profile
딥러닝 공부중

0개의 댓글