DPO : Direct Preference Optimization: Your Language Model is Secretly a Reward Model

ingeol·2023년 9월 7일
0

논문리뷰

목록 보기
19/40
post-thumbnail

Intro

해당논문 이전까지는 human preference를 높이기위한 방법으로 강화학습을 적용했을때 가장 성공적인 결과가 나왔음

+) RM의 경우 본인이 직접 뭔가 작성하는 것보다 남들이 작성해놓은 것을 보고 평가하는 것이 더 일관성있는 어노테이션이 가능 → 이 rlhf는 최근 llm에 적용되는 방식들은 다 preference데이터를 사용해 답변이 좋으면 좋은 스코어를 주는 형식의 RM 모델 형식으로 만들어 1. SFT, 2. make RM, 3. LM policy 최적화 하는 단계로 진행된다고 생각하면 된다.

Demonstration data 를 많이 만들어서 지도학습으로 finetuning하는 경우 훈련된 annoter라도 차이가 있고 모델입장에서는 annotation품질이 낮은 것도 학습이 진행되는 문제점이 존재함.

DPO : 리워드 모델링 과정 생략 preference 데이터를 directly RM 최적화에 사용하는 방법 제안


reward loss :

maxπExD,yπ[r(x,y)]βDKL[π(yx)πref(yx)]=maxπExDEyπ(yx)[r(x,y)βlogπ(yx)πref(yx)](1)DKL=iP(i)logP(i)/Q(i)=minπExDEyπ(yx)[logπ(yx)πref(yx)1βr(x,y)](2)\begin{aligned} \max_{\pi} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi} & {[r(x, y)]-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi(y \mid x) \| \pi_{\mathrm{ref}}(y \mid x)\right] } \\& =\max_{\pi} \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{y \sim \pi(y \mid x)}\left[r(x, y)-\beta \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\right] \cdots (1) \\ & \because D_{KL} = \sum_{i}P(i) \log P(i)/Q(i) \\ & =\min _{\pi} \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{y \sim \pi(y \mid x)}\left[\log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}-\frac{1}{\beta} r(x, y)\right] \cdots (2) \\ \end{aligned}

(2) 에서는 -1 X 1/ β\beta,

괄호 안에만보면

=log(π(yx)πref(yx))log(exp(1β)r(x,y))=log(π(yx)πref(yx)log(exp(1β)r(x,y)))=log(π(yx)/Z(x)πref(yx)log(exp(1β)r(x,y))/Z(x))=log(π(yx)πref(yx)log(exp(1β)r(x,y))/Z(x))logZ(x)= \log \left( \frac{\pi(y|x)}{\pi_{ref}(y|x)} \right) - \log( \exp(\frac{1}{\beta})r(x,y)) \\ = \log \left( \frac{\pi(y|x)}{\pi_{ref}(y|x) \log( \exp(\frac{1}{\beta})r(x,y))} \right) \\ = \log \left( \frac{\pi(y|x)/Z(x)}{\pi_{ref}(y|x) \log( \exp(\frac{1}{\beta})r(x,y))/Z(x)} \right) \\ = \log \left( \frac{\pi(y|x)}{\pi_{ref}(y|x) \log( \exp(\frac{1}{\beta})r(x,y))/Z(x)} \right) -\log Z(x)\\

식 정리 그 결과

let)Z(x)=yπref(yx)exp(1/β×r(x,y))let)π(yx)=1Z(x)πref(yx)exp(1βr(x,y))(3)let)\quad Z(x) = \sum_{y} \pi_{ref} (y|x)\exp(1/\beta \times r(x,y)) \\ let) \quad \pi^{*}(y|x) = \frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right) \dots (3)
=minπExDEyπ(yx)[logπ(yx)1Z(x)πref(yx)exp(1βr(x,y))logZ(x)]=\min_{\pi} \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{y \sim \pi(y \mid x)}\left[\log \frac{\pi(y \mid x)}{\frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)}-\log Z(x)\right]

두 개의 정의를 위의 식에 대입하면

=minπExDEyπ(yx)[logπ(yx)π(yx)logZ(x)]=\min_{\pi} \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{y \sim \pi(y \mid x)} \left[ \log \frac{\pi (y|x)}{\pi^{*}(y|x)} -\log Z(x) \right]

해당 식을 kl divergence로 정리 할 수 있음,, Z(x)Z(x)가 + 로 바뀌는지는 모르겠음,,,

=minπExD(DKL[π(yx)π(yx)]+Z(x))(4)=\min_{\pi} \mathbb{E}_{x \sim \mathcal{D}} (\mathbb{D}_{\mathrm{KL}}\left[\pi(y \mid x) \| \pi^{*}(y \mid x) \right] + Z(x)) \dots (4)

해당 식 전체를 minimize하는 경우는 KL divergence = 0 인경우, 즉 target의 optim solution ⇒ π\pi^{*} 다시 말해 결국 대답이 같아 지는게 optim solution이 된다.

π(yx)=π(yx)=1Z(x)πref(yx)exp(1βr(x,y))\pi(y|x) = \pi^{*}(y|x) = \frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)

4번 수식으로 KL divergence 가 같을 때 loss가 최소이므로 3번식으로 돌아가서 위와같이 쓸 수 있다.

r(x,y)r(x,y) 에 의한 식( reward 함수에 관한 수식으로 변경해주면)

r(x,y)=βlogπ(yx)Z(x)πref(yx)=βlogπ(yx)πref(yx)+βlogZ(x)(5)r^{*}(x,y) = \beta \log \frac{\pi^{*}(y|x)Z(x)}{\pi_{ref}(y|x)} = \beta \log \frac{\pi^{*}(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x) \dots (5)

key idea : 로그 확률분포의 비율 (policy 와 ref model ) 이게 의미하는 것은 train our policy, 즉reward function을 의미하게 된다. 또한 같은 의미로 reward 최적화를 만드는 것이므로 human preference를 만족시키게 된다.


B-T (bradley-terry preference model ,1952에 나온 논문)

p(y1y2x)=exp(r(x,y1))exp(r(x,y1))+exp(r(x,y2))(6)p^{*}\left(y_{1} \succ y_{2} \mid x \right) = \frac{\exp \left(r^{*} \left(x, y_{1} \right)\right)} {\exp \left(r^{*} \left (x, y_{1} \right) \right) + \exp \left(r^{*}\left(x, y_2\right)\right)} \dots (6)

5번 수식의 최적 reward를 선호도 모델에 대입해 y1> y2 ( y1이 우수한 대답) 을

p(y1y2x)=exp(βlogπ(y1x)πref(y1x)+βlogZ(x))exp(βlogπ(y1x)πref(y1x)+βlogZ(x))+exp(βlogπ(y2x)πref(y2x)+βlogZ(x))=11+exp(βlogπ(y2x)πref(y2x)βlogπ(y1x)πref(y1x))=σ(βlogπ(y1x)πref(y1x)βlogπ(y2x)πref(y2x))\begin{aligned} p^{*}\left(y_{1} \succ y_{2} \mid x\right) & = \frac{\exp \left( \beta \log \frac{\pi^{*}\left(y_{1} \mid x\right)}{\pi_{\mathrm{ref}} \left(y_{1} \mid x\right)}+ \beta \log Z(x)\right)} {\exp \left( \beta \log \frac{\pi^{*}\left(y_{1} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{1} \mid x\right)}+\beta \log Z(x)\right)+\exp \left(\beta \log \frac{\pi^{*}\left(y_{2} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{2} \mid x\right)}+\beta \log Z(x)\right)} \\ & =\frac{1}{1+\exp \left(\beta \log \frac{\pi^{*} \left(y_{2} \mid x \right)}{\pi_{ref }\left(y_{2} \mid x\right)}-\beta \log \frac{\pi^{*}\left(y_{1} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{1} \mid x\right)}\right)} \\ & =\sigma \left( \beta \log \frac{\pi^{*} \left(y_{1} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{1} \mid x\right)}-\beta \log \frac{\pi^{*}\left(y_{2} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{2} \mid x\right)}\right) \end{aligned}

맨 위에 식 에서 각 로그를 a,b,c로 각각 치환했을 때 아래와 같이 쉽게 정리할 수 있음

ea+bea+b+ec+b=eaea+ec=11+eca\frac{e^{a+b}}{e^{a+b} + e^{c+b}} = \frac{e^{a}}{e^{a} + e^{c}} = \frac{1}{1 + e^{c-a}}

또한 sigmoid는 1/1+ex{1}/{1 + e^{-x}} 이기 때문에 σ(ac)\sigma(a-c) 로 나타낼 수 있다.


마지막,,, 미분

LDPO(πθ;πref)=E(x,yw,yl)D[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right]

시그마 안에 부분 (βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right) 을 싹다 uu로 치환 후 미분

참고1 : σ(u)=σ(u)(1σ(u))\sigma^{\prime}(u) = \sigma(u)(1-\sigma(u))
참고2 : 1σ(u)=σ(u)1- \sigma(u) = \sigma(-u),

logσ(u)dudθ=σ(u)σ(u)u=σ(u)(1σ(u))σ(u)u=(1σ(u))u=σ(u)u\log \sigma(u) \frac{du}{d \theta} = \frac{\sigma^{\prime}(u)}{\sigma(u)}u^{\prime} = \frac {\sigma(u)(1-\sigma(u))}{\sigma(u)}u^{\prime}= (1-\sigma(u)) u^{\prime} = \sigma(-u)u^{\prime}

uu^{\prime}은 속미분 결과,,,, (아래와 같이 또 치환) u=r^θ(x,yw)r^θ(x,yl)u = \hat r_{\theta}(x,y_{w}) - \hat r_{\theta}(x,y_{l})

r^θ(x,yw)=βlogπθ(ywx)πref(ywx),r^θ(x,yl)=βlogπθ(ylx)πref(ylx),\hat r_{\theta}(x,y_{w}) = \beta \log \frac{\pi_{\theta}\left(y_{w} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}, \quad \hat r_{\theta}(x,y_{l}) = \beta \log \frac{\pi_{\theta}\left(y_{l} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)},
θLDPO(πθ;πref)=βE(x,yw,yl)D[σ(r^θ(x,yl)r^θ(x,yw))[θlogπ(ywx)θlogπ(ylx)]]\begin{aligned}& \nabla_\theta \mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)= \\& -\beta \mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\sigma\left(\hat{r}_\theta\left(x, y_l\right)-\hat{r}_\theta\left(x, y_w\right)\right) \left[\nabla_\theta \log \pi\left(y_w \mid x\right)-\nabla_\theta \log \pi\left(y_l \mid x\right)\right]\right]\end{aligned}

loss 변화의 해석

βE(x,yw,yl)Dσ(r^θ(x,yl)r^θ(x,yw))\beta \mathbb{E}{\left(x, y_w, y_l\right) \sim \mathcal{D}}\sigma\left(\hat{r}\theta\left(x, y_l\right)-\hat{r}_\theta\left(x, y_w\right)\right) : weight by how incorrect the model is

[θlogπ(ywx)θlogπ(ylx)]\left[\nabla_\theta \log \pi\left(y_w \mid x\right)-\nabla_\theta \log \pi\left(y_l \mid x\right)\right] : increase y_w prob, decrease y_l prob

최종결과

θLDPO(πθ;πref)=θE(x,yw,yl)D[logσ(βlogπθ(ylx)πref(ylx)βlogπθ(ywx)πref(ywx))]\nabla_\theta \mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)=-\nabla_\theta \mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}\right)\right]

실험결과

step 증가, KL 증가에 변동성이 적음(ppo에 비해)

0개의 댓글