[Paper Review] RL with KL penalties is better viewed as Bayesian inference

stat._.jun·5일 전

Case : Language model

  • X\mathcal{X} : set of sequences of tokens from some vocabulary
  • π\pi : Language Model (LM)
  • π0\pi_0 : Pretrained LM
  • π(x)\pi(x) : probability of sequence xXx \in \mathcal{X}
  • r(x)r(x) : Reward (Human preferences)

Reinforcement Learning with Human Feedbacks (RLHF)에서 흔히 사용하는 목적식은 아래와 같다.

maxπ{Exπ[r(x)]βKL(ππo)}\max_{\pi} \left\{ \mathbb{E}_{x \sim \pi}[r(x)] - \beta KL(\pi \| \pi_o) \right\}

위 목적식을 최대화하는 π\pi^*는 Donsker-Varadahan Lemma (or duality formula)에 의해 아래의 Gibbs Posterior로 달성됨이 알려져있다. (이를 위해 논문의 Appendix에서도 짧은 가정을 두었다.)

π(x)exp(r(x)/β)πo(x)\pi^*(x) \propto \exp( r(x) / \beta) \pi_o(x)

그런데, 일반적으로 Posterior는 구하는게 불가능하다. 전체 Seqeunce space X\mathcal{X}에 대해서 exponential tilting하는게 어려우니깐... 그래서 자연스러운 접근으로 θ\theta로 parametrized된 LM πθ\pi_{\theta}를 고려할 수 있고, 아래와 같이 목적식을 다시 표현할수 있다.

maxθ{Exπθ[r(x)]βKL(πθπo)}\max_{\theta} \left\{ \mathbb{E}_{x \sim \pi_{\theta}}[r(x)] - \beta KL(\pi_{\theta} \| \pi_o) \right\}

논문의 그림을 보면 더 이해가 잘되는 것 같다.

Without KL Penalty

KL Term을 빼면, Expected Reward Maximization의 문제가 된다. 따라서 Reward가 가장 큰 Sequence에 모든 mass를 두는 dirac delta로 붕괴한다.

Key contribution?

제목 그대로, RLHF를 Bayesian의 언어로 번역하여 다시 프레이밍한 것이 주된 기여같다.

RLHF 관련 글을 볼 때, π0\pi_0가 사전분포의 역할을 하고 loss를 -Reward로 사용하는 Variational Bayes로 해석할 수 있는지 궁금했는데, 이렇게 정리된 논문이 있었다.

Korbak, Tomasz, Ethan Perez, and Christopher Buckley. "RL with KL penalties is better viewed as Bayesian inference." Findings of the Association for Computational Linguistics: EMNLP 2022. 2022.

0개의 댓글