Classifier Free Guidance

이의찬 · June 10, 2024

Classifier Guidance

score function

$$\nabla \log p(\mathbf{x}_t) = - \frac{1}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_{\theta}(\mathbf{x}_t)$$

- Moving along the gradient of the log-likelihood increases the likelihood of the sample.
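The relation above can be sketched numerically. Here `eps` is a placeholder array standing in for the network output $\epsilon_{\theta}(\mathbf{x}_t)$, and `alpha_bar_t` is an illustrative value of the cumulative noise-schedule product $\bar{\alpha}_t$, not outputs of a trained model:

```python
import numpy as np

def score_from_eps(eps: np.ndarray, alpha_bar_t: float) -> np.ndarray:
    """Convert the predicted noise eps_theta(x_t) into the score:
    grad log p(x_t) = -eps / sqrt(1 - alpha_bar_t)."""
    return -eps / np.sqrt(1.0 - alpha_bar_t)

# Illustrative values (not from a real model).
eps = np.array([0.5, -0.2, 0.1])
score = score_from_eps(eps, alpha_bar_t=0.9)
```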



Applying Bayes' rule to the conditional score (the $\nabla \log p(y)$ term vanishes because $p(y)$ does not depend on $\mathbf{x}_t$):

$$\begin{aligned}
\nabla \log p(\mathbf{x}_t \mid y) &= \nabla \log \left( \frac{p(\mathbf{x}_t)\, p(y \mid \mathbf{x}_t)}{p(y)} \right) \\
&= \nabla \log p(\mathbf{x}_t) + \nabla \log p(y \mid \mathbf{x}_t) - \nabla \log p(y) \\
&= \underbrace{\nabla \log p(\mathbf{x}_t)}_{\text{unconditional score}} + \underbrace{\nabla \log p(y \mid \mathbf{x}_t)}_{\text{adversarial gradient}}
\end{aligned}$$



With a guidance scale $\gamma$, the classifier term is amplified:

$$\nabla \log p(\mathbf{x}_t \mid y) = \nabla \log p(\mathbf{x}_t) + \gamma \nabla \log p(y \mid \mathbf{x}_t)$$
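In $\epsilon$-space this scaled update becomes a shift of the noise prediction: using $\nabla \log p(\mathbf{x}_t) = -\epsilon_{\theta}(\mathbf{x}_t)/\sqrt{1-\bar{\alpha}_t}$, the guided score corresponds to $\hat{\epsilon} = \epsilon_{\theta}(\mathbf{x}_t) - \gamma \sqrt{1-\bar{\alpha}_t}\, \nabla \log p(y \mid \mathbf{x}_t)$. A minimal sketch, where both `eps` and `classifier_grad` are placeholder arrays rather than real network outputs:

```python
import numpy as np

def classifier_guided_eps(eps, classifier_grad, alpha_bar_t, gamma):
    """Shift the noise prediction by the scaled classifier gradient:
    eps_hat = eps - gamma * sqrt(1 - alpha_bar_t) * grad log p(y | x_t)."""
    return eps - gamma * np.sqrt(1.0 - alpha_bar_t) * classifier_grad

eps = np.array([0.5, -0.2, 0.1])    # placeholder eps_theta(x_t)
grad = np.array([0.3, 0.0, -0.1])   # placeholder grad log p(y | x_t)
eps_hat = classifier_guided_eps(eps, grad, alpha_bar_t=0.9, gamma=2.0)
```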



Classifier Free Guidance

$$\underbrace{\nabla \log p(\mathbf{x}_t)}_{\text{unconditional score}} + \underbrace{\nabla \log p(y \mid \mathbf{x}_t)}_{\text{adversarial gradient}}$$



Rearranging Bayes' rule expresses the classifier gradient using only the two diffusion scores, so no separate classifier is needed:

$$\nabla \log p(y \mid \mathbf{x}_t) = \nabla \log p(\mathbf{x}_t \mid y) - \nabla \log p(\mathbf{x}_t)$$



$$\begin{aligned}
\nabla \log p(\mathbf{x}_t \mid y) &= \nabla \log p(\mathbf{x}_t) + \gamma \left( \nabla \log p(\mathbf{x}_t \mid y) - \nabla \log p(\mathbf{x}_t) \right) \\
&= \nabla \log p(\mathbf{x}_t) + \gamma \nabla \log p(\mathbf{x}_t \mid y) - \gamma \nabla \log p(\mathbf{x}_t) \\
&= \underbrace{\gamma \nabla \log p(\mathbf{x}_t \mid y)}_{\text{conditional score}} + \underbrace{(1 - \gamma) \nabla \log p(\mathbf{x}_t)}_{\text{unconditional score}}
\end{aligned}$$
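The final combination translates directly into $\epsilon$-space: since each score is $-\epsilon/\sqrt{1-\bar{\alpha}_t}$, the common factor cancels and $\hat{\epsilon} = (1-\gamma)\,\epsilon_{\text{uncond}} + \gamma\,\epsilon_{\text{cond}} = \epsilon_{\text{uncond}} + \gamma(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$. A minimal sketch, with placeholder arrays standing in for the two network passes (conditional and null-label):

```python
import numpy as np

def cfg_eps(eps_uncond, eps_cond, gamma):
    """Classifier-free guidance combination of the two noise predictions:
    eps_hat = eps_uncond + gamma * (eps_cond - eps_uncond)."""
    return eps_uncond + gamma * (eps_cond - eps_uncond)

eps_uncond = np.array([0.5, -0.2, 0.1])  # placeholder eps_theta(x_t) with null label
eps_cond = np.array([0.7, -0.1, 0.0])    # placeholder eps_theta(x_t, y)
eps_hat = cfg_eps(eps_uncond, eps_cond, gamma=3.0)
# gamma = 1 recovers the plain conditional prediction;
# gamma > 1 pushes the sample further toward the condition y.
```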
