[Diffusion] DDIM 논문 리뷰: DENOISING DIFFUSION IMPLICIT MODELS

whateverpartysover·2025년 3월 25일

Diffusion

목록 보기
2/5
post-thumbnail

한동안 면접 러쉬로 인해 중단되었던 공부를 재개.. 잠시 정리를 미루어놓았던 DDIM과 Flow Matching 논문을 일단 정리해놓아야겠다.

DDIM은 사실상 DDPM에 수식 장난질을 추가한 것이기 때문에
DDPM 논문 리뷰 를 보고 오세요..

사실 DDIM의 실제 구현만 떼놓고 보면 사실상 DDPM 스텝 줄이는거말고는 없다.
하지만 논문자체로 Diffusion Model의 수학적인 본질과 이후 연구에서까지의 이를 해결하는 과정에서 인사이트를 줄 수 있는 논문인것 같다.

논문: https://arxiv.org/abs/2010.02502


Abstract 베끼기

DDPM은 성능이 좋지만 느리다.
이유: 생성 단계에서 엄청 많은 step 에 대한 마르코프 체인 과정이 일어나야한다.
해결: 이 과정을 deterministic하게 하는 non Markovian process 도입하여 빨라지게 한다.

  • 10~50배 빠른 생성속도
  • 연산량과 샘플 품질 trade-off를 조절 가능
  • latent space에서 semantically meaningful interpolation 수행 가능

기술 개요

Variational Inference For Non-Markovian Forward Process

DDPM의 loss LγL_\gammaq(xtx0)q(x_t|x_0)와 같은 marginals에만 의존하고, q(x1:Tx0)q(x_{1:T}|x_0)와 같은 joint distribution에는 직접적으로 의존하지 않는다

저자들은 위 사실에 집중했다고 한다.
분명 처음에는 Markov chain이니까 joint distribution에 관한 식이었을 것인데, 뭐 상수니까 빼고 실험적인 이유로 빼고 하다보니 어느시 marginals만 식에 남았네?

그렇기 때문에 marginals만 공유를 한 채로, joint distribution을 재설계하여 샘플링 속도를 빠르게 해볼수 있다. 어떻게? non-Markovian으로 전환시키는 것을 통해!

Non-Markovian Forward Process

냅다 들이미는 새로운 joint

qσ(x1:Tx0):=qσ(xTx0)t=2Tqσ(xt1xt,x0)q_\sigma(x_{1:T} | x_0) := q_\sigma(x_T | x_0) \prod_{t=2}^{T} q_\sigma(x_{t-1} | x_t, x_0)

그리고 식에서 각 항의 의미는 다음과 같다.

marginal

qσ(xTx0)=N(αˉTx0, (1αˉT)I)q_\sigma(x_T | x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T} x_0,\ (1 - \bar{\alpha}_T) I)

기존 DDPM의 marginal과 동일하게 가져간다

역방향 조건부 분포

qσ(xt1xt,x0)=N(αˉt1x0+1αˉt1σt2xtαˉtx01αˉt, σt2I)q_\sigma(x_{t-1} | x_t, x_0) = \mathcal{N}\left( \sqrt{\bar{\alpha}_{t-1}} x_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t} x_0}{\sqrt{1 - \bar{\alpha}_t}}, \ \sigma_t^2 I \right)

일단 조건부에 x0x_0가 붙었으니 결국 non-Markovian
평균의 식이 굉장히 복잡하지만, 이는 DDPM의 marginal을 유지할 수 있도록 선택된것이며, 이렇게 했을 때 marginal이 유지된다는 lemma는 appendix에서 증명된다.

암튼 이 식들을 가지고 Bayes rule 써서 forward process를 정의하면 되겠다. (수식 귀찮아)

그리고 여기서 σ\sigma 가 0에 가까워질 수록 분산도 0에 가까워지면서 deterministic하게 된다는 점을 기억합시다.

Generative Process And Unified Variational Inference Objective

generative process

생성 프로세스 pθ(x0:T)p_\theta(x_{0:T}) 를 정의해봅시다.
기존 DDPM에서는 그냥 통짜로 이걸 학습 때렸지만, 이제 미리는 미리 정의해둔 joint를 활용할수가 있다. (pθ(t)(xt1xt)p_\theta^{(t)}(x_{t-1}|x_t)가 관심사이지만 qσ(xt1xt,x0)q_\sigma(x_{t-1}|x_t,x_0)을 알고 있으므로.. ㅎㅎ)

요 프로세스는 xtx_t 를 가지고 x0x_0를 예측해놓고 이거로 qσ(xt1xt,x0)q_\sigma(x_{t-1}|x_t,x_0)를 계산해서 xt1x_{t-1} 를 샘플링하는 형태로 이루어질 수 있다.

x0x_0의 예측은 기존 ddpm 수식을 활용해서 아래와 같이 나타낼수 있고

fθ(t)(xt):=xt1αtϵθ(t)(xt)αtf_\theta^{(t)}(x_t) := \frac{x_t - \sqrt{1 - \alpha_t} \cdot \epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}

고렇다면 우리가 필요한 생성 프로세스 pp는 결국 다음과 같이 마무리 칠 수 있다.

pθ(t)(xt1xt)={N(fθ(1)(x1), σ12I)if t=1qσ(xt1xt, fθ(t)(xt))otherwisep_\theta^{(t)}(x_{t-1} | x_t) = \begin{cases} \mathcal{N}(f_\theta^{(1)}(x_1),\ \sigma_1^2 I) & \text{if } t = 1 \\ q_\sigma(x_{t-1} | x_t,\ f_\theta^{(t)}(x_t)) & \text{otherwise} \end{cases}

t=1일때의 의미는.... 이해는 못하겠지만... 암튼 실제 DDIM에선 샘플링시 σ\sigma를 0으로 쓰기때문에 의미는 없고 (스포 ㅎ) 일반화를 위한것이라고 보면 될것 같다

objective

loss 자체는 역시 변분추론 해갖구 하면

Jσ(ϵθ):=Ex0:Tqσ[logqσ(x1:Tx0)logpθ(x0:T)]=Ex0:Tqσ[logqσ(xTx0)+t=2Tlogqσ(xt1xt,x0)t=1Tlogpθ(t)(xt1xt)logpθ(xT)]J_\sigma(\epsilon_\theta) := \mathbb{E}_{x_{0:T} \sim q_\sigma} \left[ \log q_\sigma(x_{1:T} | x_0) - \log p_\theta(x_{0:T}) \right] \\ = \mathbb{E}_{x_{0:T} \sim q_\sigma} \left[ \log q_\sigma(x_T | x_0) + \sum_{t=2}^T \log q_\sigma(x_{t-1} | x_t, x_0) - \sum_{t=1}^T \log p_\theta^{(t)}(x_{t-1} | x_t) - \log p_\theta(x_T) \right]

흠 근데 목적함수가 σ\sigma에 의존을 하는걸 보니... σ\sigma 선택을 잘 해서 좋은 모델 학습하기가 빡셀듯 싶은데요?

여기서 등장하는

Theorem 1
For all σ>0\sigma > 0, there exists γR>0T\gamma \in \mathbb{R}_{>0}^T and CRC \in \mathbb{R}, such that Jσ=Lγ+CJ_\sigma = L_\gamma + C

LγL_\gamma는 DDPM에서 최종적으로 도출된 surrogate objective
아니 gamma가 뭔데요 시팔

이전 loss 유도 할때 Lt1L_{t-1}항 (실제로 살아남는 항)

Lt1=Eq[12σt2μ~t(xt,x0)μθ(xt,t)2]+CL_{t-1} = \mathbb{E}_q \left[ \frac{1}{2\sigma_t^2}|| \tilde \mu_t(x_t,x_0) - \mu_\theta(x_t,t) ||^2 \right] + C

이 요런식으로 12σt2\frac{1}{2\sigma_t^2} 가 붙어있었고 이게 γ\gamma 인데요... 실험적으로 1로 취급해도 잘된다구 해서 빼버렸었읍니다...

암튼 theorem을 풀어서 이야기 하자면 σ\sigma는 어떻게 놓더라도 DDPM loss에 대응되는 형태로 풀어지기 때문에 사실상 DDPM 방식으로 학습해도 요 non Markvian 모델을 쓸 수 있다.

이 또한 굉장히 중요해보이지만 증명은 appendix 참조..
슬쩍 보았을때 KLD 형태의 loss 식을 잘 만져서 공분산 같은 두 분포의 KLD로 만들고 이걸 MSE 형태로 바꿔서 어찌저찌 유도를 해냈다

추가로 γ\gamma 는 만약 모델 ϵθ(t)\epsilon_\theta^{(t)}의 파라미터 θ\theta가 서로 다른 timestep tt 간에 공유되지 않는다면,
각 항을 개별적으로 최적화할 수 있기 때문에, 최적해는 가중치 γ\gamma에 의존하지 않게 된다.

이거 이해 안됐어서 조금 더 부연을 하자면,

Lγ=t=1TγtLosst(θ(t))L_\gamma = \sum_{t=1}^T \gamma_t \cdot \text{Loss}_t(\theta^{(t)})

요런 식으로 loss가 쓰일수가 있을텐데 tt 시점에서는 그냥 그 시점에서의 최적값만 찾으면 전체도 최적화 되기 때문에 γ\gamma가 최적점 자체에 직접적인 영향을 끼치진 않을것이란 뜻이다.

암튼 요렇게 밑밥을 깔아줬으니 이전 논문에서 L1L_1을 쓴게 정당화가 됐고 여기서도 Lemma 따라서 JσJ_\sigma는 알아서 정해질테니 L1L_1 쓰겠다 라는 것임

Sampling From Generalized Generative Processes

앞에서 보았듯 이제 L1L_1 학습은 Markovian process 학습만을 나타내는 것이 아니고, σ\sigma로 파라미터화된 Non-Markovian process를 학습하는 것이 된다.
즉 학습은 그냥 DDPM으로 하고, σ\sigma를 달리해가면서 샘플 생성을 다양한 목적에 맞게 진행할수 있게 된다.

Denoising Diffusion Implicit Models

xt1=αt1(xt1αtϵθ(t)(xt)αt)“predicted x0+1αt1σt2ϵθ(t)(xt)“direction pointing to xt+σtϵtrandom noisex_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{ \left( \frac{x_t - \sqrt{1 - \alpha_t} \cdot \epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}} \right) }_{\text{“predicted } x_0\text{”}} + \underbrace{ \sqrt{1 - \alpha_{t-1} - \sigma_t^2} \cdot \epsilon_\theta^{(t)}(x_t) }_{\text{“direction pointing to } x_t\text{”}} + \underbrace{ \sigma_t \cdot \epsilon_t }_{\text{random noise}}

사악 종합해서 xtx_t로부터 xt1x_{t-1}를 뽑는 식을 위 형태로 쓸 수 있다.

위 식의 특성을 뜯어보자면

1) DDPM 학습 그대로 해도 된다.

σ\sigma를 다르게 놓아서 생성 과정 자체를 달리 할수는 있지만, 여기서 활용되는 모델 즉 ϵθ\epsilon_\theta 는 똑같은걸 써도 된다.

2) 일반화된 식이다.

σt=1αt1αt11αt1αt\sigma_t = \sqrt{ \frac{1 - \alpha_t}{1 - \alpha_{t-1}} } \cdot \sqrt{ 1 - \frac{\alpha_{t-1}}{\alpha_t} }

로 두면, forward process는 Markovian process가 되고, 생성 과정도 DDPM과 똑같다.

3) σ=0\sigma=0 일 때:

랜덤성이 부여되는 노이즈항이 완전히 사라지고, xt1x_{t-1}xtx_tx0x_0 에 대해 deterministic해진다.
요걸 이제 DDIM이라고 한다. 참고로 발음은 /d:Im/ 이라고 한다. (응 난 디디아이엠이라 할거야)
아무튼 요건 implicit probabilistic model 의 한 형태가 된다고 하는데 이 논문 모르고 보기도 귀찮아서 챗지피티한테 물어보았다.

📌 정의:
명시적인 확률분포(explicit distribution)를 정의하지 않고,
대신 샘플링 과정을 통해서만 간접적으로 데이터 분포를 모델링하는 모델을 의미해.

즉, 수학적으로 "이게 우리가 모델링하는 확률분포다" 라는 p(x)를 정확하게 써줄 수는 없지만,
그 분포로부터 샘플을 생성하는 절차만 정의되어 있는 모델이야.

음 그렇군요

아무튼 요 시그마 제로! 이게 이제 stochastic한 생성과정을 deterministic하게 만들어 주었다는것이 포인트입니다
요렇게 deterministic한 샘플링 과정 덕분에 중간 중간 그냥 스킵하고 짬프를 뛰어도 샘플링을 해낼 수가 있게 되고 생성 프로세스를 줄일수가 있더라~하는 이야기

Relevance To Neural ODES

요 파트가 이해하기도 어렵기도하고 논문 나왔을 당시로서는 큰 의미도 모르겠고 해서 예전엔 제끼면서 넘어갔던 부분이고, 다른 블로그 리뷰글에서도 잘 다루어지지는 않는것 같다. 하지만 글 초반에서 언급했던 이론적 인사이트의 핵심이 될 부분이 아닌가 싶다.

DDIM 샘플링 식을 요렇게 쓸수 있다.

xtΔtαtΔt=xtαt+(αtΔt1αtΔtαt1αt)ϵθ(t)(xt)\frac{x_{t - \Delta t}}{\sqrt{\alpha_{t - \Delta t}}} = \frac{x_t}{\sqrt{\alpha_t}} + \left( \sqrt{\frac{\alpha_{t - \Delta t}}{1 - \alpha_{t - \Delta t}}} - \sqrt{\frac{\alpha_t}{1 - \alpha_t}} \right) \cdot \epsilon_\theta^{(t)}(x_t)

요기서 reparam 해주면

1αασ,xαxˉ\frac{\sqrt{1 - \alpha}}{\sqrt\alpha} \Rightarrow \sigma, \quad \frac{x}{\sqrt{\alpha}} \Rightarrow \bar{x}
dxˉ(t)dσ(t)=ϵθ(t)(xˉ(t)σ(t)2+1)\frac{d\bar{x}(t)}{d\sigma(t)} = \epsilon_\theta^{(t)} \left( \frac{\bar{x}(t)}{\sqrt{\sigma(t)^2 + 1}} \right)

initial condition

x(T)N(0, σ(T))for a very large σ(T)x(T) \sim \mathcal{N}(0,\ \sigma(T)) \quad \text{for a very large } \sigma(T)

(사실 눈으로는 직접적으로 안보이는데 암튼 이렇대요)

암튼 하고자 하는 말은 이 step 간 업데이트가 오일러 방식으로 ODE 푸는 것과 굉장히 유사해지고, 이를 연속화하면 걍 ODE 문제가 되더라...

기존 DDPM이 결국 노이즈에 의존하는 SDE 문제였는데, 이를 deterministic하게 구조를 바꿈으로서 ODE 형태에 가깝게 풀어냈다는 것이되고, 이러한 특성이 빠른 특성에 기여하지 않았을까 하는 해석도 되지 않을까오ㅛ 아님말고

아무튼 요 SDE 문제였던 기존 diffusion을 ODE 문제로 전환한다는 것의 의미!
Flow Matching에서 이어집니다...

profile
인공지능 못해요

0개의 댓글