Adaptation with Self-Evaluation to Improve Selective Prediction in LLMs

임재석 · February 1, 2024

1. Introduction

  • LLMs are not guaranteed to be accurate for all queries

  • Understanding which queries they are reliable for is important

  • Selective Prediction : a deployment scenario for AI in which humans review low-confidence, AI-generated outputs to maintain overall accuracy

    • Both human and AI performance are considered together to minimize the cost of human involvement
    • The AI should use selective prediction to assess the accuracy of its predictions and refrain from making wrong ones
    • i.e., it can say "I don't know" when it is not confident in its prediction
  • Selective prediction is hard because LLMs are trained to predict the "next" token, not necessarily the "correct" one

  • LLMs also do not produce an explicit confidence score → obtaining a confidence score from the output sequence is not straightforward

  • Distinguishing correct from incorrect outputs based on likelihood scores is challenging

    • Prompting (e.g. "Is the proposed answer True or False?") → does not generalize well to other LLMs
    • Semantic entropy or self-consistency → require generating multiple output sequences
    • Fine-tuning LLMs on the target task can increase the likelihood of the ground-truth answer → this is not the same as minimizing the likelihood of wrong answers, so the model can still generate wrong answers
  • ASPIRE : learns self-evaluation from target-task data

    • trains the LLM on a subset of the training data from the QA tasks
    • defines a selection score that combines the likelihood of the generated answer with the learned self-eval score to make selective predictions
    • less computationally expensive than generating multiple output sequences

2. Related Work

Selective Predictions for LLMs

  • Selective prediction for classification (e.g. NLI) vs. selective prediction for NLG
    • NLG tasks have an infinite space of possible answers
  • Uncertainty measures for LLMs
  • Using selective prediction for QA when the question is ambiguous
  • Using an auxiliary model to distinguish the correct predictions of a QA model

Parameter Efficient Fine-Tuning (PEFT)

  • LoRA
  • Prefix Tuning
  • Soft Prompt Tuning → used!
  • P-Tuning

3. Problem Setup

Notations

  • pretrained LLM $f$ for an arbitrary generative modeling task such as QA
  • vocabulary $\mathcal{V}$
  • space of token sequences $\mathcal{V}^*$
  • logits of $f$ on $v \in \mathcal{V}$ given $\mathbf{x} \in \mathcal{V}^*$ : $\bar{f}(v \mid \mathbf{x})$
  • the likelihood of the next token following $\mathbf{x}$ being $v$ is
    $f(v \mid \mathbf{x}) := \dfrac{\exp(\bar{f}(v \mid \mathbf{x}))}{\sum_{v' \in \mathcal{V}} \exp(\bar{f}(v' \mid \mathbf{x}))}$
    (a softmax over the vocabulary)
  • the likelihood of generating $\hat{\mathbf{y}} \in \mathcal{V}^*$ given $\mathbf{x}$ is
    $f(\hat{\mathbf{y}} \mid \mathbf{x}) := \prod_{i=1}^{|\hat{\mathbf{y}}|} f(\hat{y}_i \mid \mathbf{x}, \hat{y}_{[i-1]})$
    where $\hat{\mathbf{y}} = (\hat{y}_1, \hat{y}_2, \dots, \hat{y}_{|\hat{\mathbf{y}}|})$, $\hat{y}_{[i-1]} = (\hat{y}_1, \dots, \hat{y}_{i-1})$, and $\hat{y}_{[0]} = \emptyset$
  • this likelihood can be very small when $|\hat{\mathbf{y}}|$ is large → use the length-normalized likelihood (see the sketch after this list)
    $f_{\text{norm}}(\hat{\mathbf{y}} \mid \mathbf{x}) := f(\hat{\mathbf{y}} \mid \mathbf{x})^{1 / |\hat{\mathbf{y}}|}$
  • use $f$ to generate the output sequence by solving
    $\hat{\mathbf{y}}^* = \arg\max_{\hat{\mathbf{y}}} \log f(\hat{\mathbf{y}} \mid \mathbf{x})$
  • impossible to solve exactly since output sequences can be arbitrarily long → use a decoding strategy (greedy decoding, beam search) to approximate it
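To make the definitions concrete, here is a minimal sketch (Python/NumPy, not from the paper's code) that computes the sequence likelihood and its length-normalized version from per-step logits; `step_logits` and `token_ids` are assumed to be collected during decoding:

```python
import numpy as np

def sequence_likelihoods(step_logits, token_ids):
    """Return f(y_hat | x) and the length-normalized f_norm(y_hat | x).

    step_logits: array of shape (T, |V|); logits over the vocabulary at each decoding step
    token_ids:   length-T list of the generated token ids y_hat
    """
    log_likelihood = 0.0
    for logits, tok in zip(step_logits, token_ids):
        # softmax over the vocabulary gives f(v | x, y_hat_[i-1])
        m = logits.max()
        log_probs = (logits - m) - np.log(np.exp(logits - m).sum())
        log_likelihood += log_probs[tok]
    # f_norm(y_hat | x) = f(y_hat | x) ^ (1 / |y_hat|)
    return np.exp(log_likelihood), np.exp(log_likelihood / len(token_ids))
```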

Evaluate Correctness

  • set of reference outputs $S$

  • evaluation metric $M : \mathcal{V}^* \times \mathcal{V}^* \rightarrow [0, 1]$

    • evaluates the similarity of the generated output $\hat{\mathbf{y}}$ and a reference output $\mathbf{y}_r \in S$
  • threshold $\gamma$

    • if $\max_{\mathbf{y}_r \in S} M(\hat{\mathbf{y}}, \mathbf{y}_r) > \gamma$, the generated output is considered correct
  • training dataset $\mathcal{D}^{tr} = \{ (\mathbf{x}^i, S^i) \}_{i=1}^{n_{tr}}$ randomly sampled from the target task distribution

  • rejection option $\bot$

  • selective predictor $f_s : \mathcal{V}^* \rightarrow \mathcal{V}^* \cup \{ \bot \}$

    • should achieve strong selective prediction performance on the test dataset
    • composed of a predictor $\hat{f} : \mathcal{V}^* \rightarrow \mathcal{V}^*$ and a selection scoring function $g : \mathcal{V}^* \rightarrow \mathbb{R}$
    • $f_s(\mathbf{x}; \tau) = \begin{cases} \hat{f}(\mathbf{x}) & \text{if } g(\mathbf{x}) \ge \tau \\ \bot & \text{if } g(\mathbf{x}) < \tau \end{cases}$
    • accuracy : the fraction of accepted inputs whose predictions are correct
    • coverage : the fraction of inputs that are accepted
    • tune $\tau$ to achieve a certain coverage and manage the accuracy-coverage trade-off
  • use AUACC (area under the accuracy-coverage curve) to measure selective prediction performance (a small sketch follows this list)

  • use AUROC (area under the receiver operating characteristic curve) to measure the quality of the selection score estimation

    • equivalent to the probability that a randomly chosen correct output sequence has a higher selection score than a randomly chosen incorrect output sequence
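A rough sketch of these metrics (my own illustration, not the paper's evaluation code): `scores` are selection scores $g(\mathbf{x})$ on the test set and `correct` are 0/1 labels obtained from the metric $M$ with threshold $\gamma$:

```python
import numpy as np

def accuracy_coverage(scores, correct, tau):
    """Accuracy and coverage of the selective predictor f_s(x; tau)."""
    scores, correct = np.asarray(scores), np.asarray(correct, dtype=float)
    accepted = scores >= tau                      # g(x) >= tau -> predict, else abstain
    coverage = accepted.mean()                    # fraction of accepted inputs
    accuracy = correct[accepted].mean() if accepted.any() else 1.0
    return accuracy, coverage

def auacc(scores, correct):
    """Area under the accuracy-coverage curve, swept over all observed thresholds."""
    taus = np.sort(np.unique(scores))[::-1]       # from strictest to loosest threshold
    accs, covs = zip(*(accuracy_coverage(scores, correct, t) for t in taus))
    return np.trapz(accs, covs)                   # coverage grows as tau decreases
```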

4. ASPIRE Framework

  • The LLM should have a self-evaluation ability
    • Previous self-evaluation work is only applicable to specific LLMs
    • Collect some training data to teach self-evaluation

  • Start with task-specific tuning (parameter-efficient fine-tuning, e.g. soft prompt tuning)

    • the model parameters $\theta$ are frozen
    • adaptable parameters $\theta_p$ are added and updated during fine-tuning
    • this improves prediction accuracy and the likelihood of correct output sequences → improves selective prediction performance!
  • Fine-tune the LLM to learn self-evaluation

    • use $\theta_p$ to generate different answers for each example $(\mathbf{x}, \mathbf{y}) \in \mathcal{D}^{tr}$

    • let $\mathcal{A}$ be the decoding algorithm used to generate output sequences for $\mathbf{x}$,
      where $\mathcal{A}(f, \theta_p, \mathbf{x}) = [\hat{\mathbf{y}}^1, \dots, \hat{\mathbf{y}}^k]$

    • choose output sequences with high likelihood $f(\hat{\mathbf{y}}^j \mid \mathbf{x}; \theta_p)$

    • use the metric $M$ to determine whether $\hat{\mathbf{y}}^j$ is correct,
      i.e. if $M(\hat{\mathbf{y}}^j, \mathbf{y}) > \hat{\gamma}$, it is labeled correct

    • use a threshold $\hat{\gamma}$ different from the evaluation threshold $\gamma$ (choose $\hat{\gamma}$ sufficiently large so that wrong outputs are not labeled as correct)

    • after sampling high-likelihood outputs, tune $\theta_s$ only for learning self-evaluation ($\theta$ and $\theta_p$ are frozen)

    • the training objective is

      $\min_{\theta_s} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathcal{D}^{tr}} [\mathcal{L}_c + \mathcal{L}_w]$
      $\mathcal{L}_c = \mathbb{E}_{\hat{\mathbf{y}} \sim S_c(\mathbf{x}, \mathbf{y})} [-\log f(\text{``correct''} \mid \mathbf{x}, \hat{\mathbf{y}}; \theta_p, \theta_s)]$
      $\mathcal{L}_w = \mathbb{E}_{\hat{\mathbf{y}} \sim S_w(\mathbf{x}, \mathbf{y})} [-\log f(\text{``wrong''} \mid \mathbf{x}, \hat{\mathbf{y}}; \theta_p, \theta_s)]$

      where $S_c(\mathbf{x}, \mathbf{y})$ is the set of 'correct' outputs containing the reference $\mathbf{y}$ and the $k_c$ correct outputs with the highest likelihood from $\mathcal{A}(f, \theta_p, \mathbf{x})$; $S_w(\mathbf{x}, \mathbf{y})$ is defined analogously with the $k_w$ wrong outputs (if $\mathcal{A}(f, \theta_p, \mathbf{x})$ contains no wrong output, a default wrong output, e.g. the empty string, is added to $S_w$)

    • After training $\theta_s$, obtain the prediction by solving

      $\hat{\mathbf{y}}^* = \arg\max_{\hat{\mathbf{y}}} \log f(\hat{\mathbf{y}} \mid \mathbf{x}; \theta_p)$
    • The self-eval score is defined as

      $P(\text{correct} \mid \mathbf{x}, \hat{\mathbf{y}}^*) = \dfrac{\exp(\bar{f}(\text{correct} \mid \mathbf{x}, \hat{\mathbf{y}}^*; \theta_p, \theta_s))}{\sum_{z \in \{\text{correct}, \text{wrong}\}} \exp(\bar{f}(z \mid \mathbf{x}, \hat{\mathbf{y}}^*; \theta_p, \theta_s))}$
    • Used beam search decoding

    • Overall, the selection scoring function is (see the sketch after this list)

      $g(\mathbf{x}) = (1 - \alpha) \cdot \log f_{\text{norm}}(\hat{\mathbf{y}}^* \mid \mathbf{x}; \theta_p) + \alpha \cdot \log P(\text{correct} \mid \mathbf{x}, \hat{\mathbf{y}}^*)$

      where $\alpha \in [0, 1]$ is a hyperparameter
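A small sketch of the selection score under my own assumptions: `logit_correct` and `logit_wrong` stand for the logits $\bar{f}(\text{correct} \mid \cdot)$ and $\bar{f}(\text{wrong} \mid \cdot)$ read off after the self-evaluation prompt, and `norm_log_likelihood` is $\log f_{\text{norm}}(\hat{\mathbf{y}}^* \mid \mathbf{x}; \theta_p)$:

```python
import numpy as np

def self_eval_prob(logit_correct, logit_wrong):
    """P(correct | x, y*) via a two-way softmax over the 'correct' / 'wrong' logits."""
    m = max(logit_correct, logit_wrong)
    ec, ew = np.exp(logit_correct - m), np.exp(logit_wrong - m)
    return ec / (ec + ew)

def selection_score(norm_log_likelihood, logit_correct, logit_wrong, alpha=0.25):
    """g(x) = (1 - alpha) * log f_norm(y* | x; theta_p) + alpha * log P(correct | x, y*)."""
    p_correct = self_eval_prob(logit_correct, logit_wrong)
    return (1.0 - alpha) * norm_log_likelihood + alpha * np.log(p_correct)
```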

5. Implementation via Soft Prompt Tuning

  • The idea is to develop prompts that effectively stimulate self-evaluation
  • Such prompts can be discovered through soft prompt tuning with targeted training objectives

Soft Prompt Tuning

  • given a query $\mathbf{x} = (x_1, \dots, x_{m_q})$
  • embed $\mathbf{x}$ to form a matrix $X \in \mathbb{R}^{m_q \times d_e}$
  • soft prompt $\tilde{\theta} \in \mathbb{R}^{l \times d_e}$
  • concatenate the soft prompt and the query embeddings to form $[\tilde{\theta}; X] \in \mathbb{R}^{(m_q + l) \times d_e}$ (see the sketch below)
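A minimal PyTorch sketch of this step (illustrative, not the authors' code): a learnable matrix of shape $(l, d_e)$ is prepended to the query embeddings and the result is fed to the frozen LLM via input embeddings; the prompt length and embedding dimension below are placeholder values:

```python
import torch
import torch.nn as nn

class SoftPrompt(nn.Module):
    """Learnable soft prompt of shape (l, d_e) prepended to the query embeddings."""
    def __init__(self, prompt_length: int = 50, embed_dim: int = 1024):
        super().__init__()
        self.prompt = nn.Parameter(torch.randn(prompt_length, embed_dim) * 0.02)

    def forward(self, query_embeds: torch.Tensor) -> torch.Tensor:
        # query_embeds: (batch, m_q, d_e)  ->  [theta~; X]: (batch, l + m_q, d_e)
        prompt = self.prompt.unsqueeze(0).expand(query_embeds.size(0), -1, -1)
        return torch.cat([prompt, query_embeds], dim=1)
```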

Adapt to ASPIRE

  • update $\theta_p$ with
    $\min_{\theta_p} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathcal{D}^{tr}} \frac{1}{|\mathbf{y}|} \sum_{j=1}^{|\mathbf{y}|} -\log f(y_j \mid [\theta_p; X; Y_{[j-1]}])$
  • update $\theta_s$ with (a sketch of this loss follows this list)
    $\min_{\theta_s} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathcal{D}^{tr}} [\mathcal{L}_c + \mathcal{L}_w]$
    $\mathcal{L}_c = \mathbb{E}_{\hat{\mathbf{y}} \sim S_c(\mathbf{x}, \mathbf{y})} [-\log f(\text{``correct''} \mid [\theta_p; X; \hat{Y}; \theta_s])]$
    $\mathcal{L}_w = \mathbb{E}_{\hat{\mathbf{y}} \sim S_w(\mathbf{x}, \mathbf{y})} [-\log f(\text{``wrong''} \mid [\theta_p; X; \hat{Y}; \theta_s])]$
  • the inference objective becomes
    $\hat{\mathbf{y}}^* = \arg\max_{\hat{\mathbf{y}}} \log f(\hat{\mathbf{y}} \mid [\theta_p; X])$
  • the self-eval score becomes
    $P(\text{correct} \mid \mathbf{x}, \hat{\mathbf{y}}^*) = \dfrac{\exp(\bar{f}(\text{correct} \mid [\theta_p; X; \hat{Y}^*; \theta_s]))}{\sum_{z \in \{\text{correct}, \text{wrong}\}} \exp(\bar{f}(z \mid [\theta_p; X; \hat{Y}^*; \theta_s]))}$
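A rough sketch of the $\theta_s$ objective under my own assumptions (a frozen Hugging Face-style causal LM that accepts `inputs_embeds`, answer sequences already embedded, and `correct_id` / `wrong_id` being the token ids of the target words); everything here is illustrative rather than the authors' implementation:

```python
import torch
import torch.nn.functional as F

def self_eval_losses(llm, theta_p, theta_s, X, correct_answers, wrong_answers,
                     correct_id, wrong_id):
    """L_c + L_w for one training example; theta and theta_p stay frozen, only theta_s learns.

    theta_p, theta_s: soft prompt tensors of shape (1, l, d_e)
    X:                query embeddings of shape (1, m_q, d_e)
    *_answers:        lists of embedded answer sequences, each of shape (1, T_i, d_e)
    """
    def nll_of_label(answer_embeds, label_id):
        # [theta_p; X; Y_hat; theta_s] fed to the frozen LLM via input embeddings
        inputs = torch.cat([theta_p, X, answer_embeds, theta_s], dim=1)
        logits = llm(inputs_embeds=inputs).logits[:, -1, :]   # next-token logits
        return F.cross_entropy(logits, torch.tensor([label_id], device=logits.device))

    loss_c = torch.stack([nll_of_label(y, correct_id) for y in correct_answers]).mean()
    loss_w = torch.stack([nll_of_label(y, wrong_id) for y in wrong_answers]).mean()
    return loss_c + loss_w
```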

Generation Pipeline

  • Stage 1 : obtain the generated output and its (normalized) likelihood
  • Stage 2 : obtain the self-eval score for that output
  • cache the LLM states from the first stage to reduce the computational cost of the second stage

Computational Complexity

  • ASPIRE at test time : $O(l_{\max})$, where $l_{\max}$ is the maximum output sequence length
  • Predictive entropy and semantic entropy methods : $O(m \cdot l_{\max})$, since $m$ output sequences must be generated

6. Experiments

  • Using a decoding algorithm that can sample diverse high-likelihood sequences is important
  • More training samples lead to better performance
  • ~2K training samples are already enough to outperform the baselines that do not use soft prompt tuning

6.1 Setup

  • free-form QA tasks : CoQA (zero-shot), SQuAD (zero-shot), TriviaQA (5-shot)
  • used a 50K-example training subset
  • OPT (350M, 1.3B, 2.7B, 30B), GPT-2 (Medium, Large, XL)
  • compare the pretrained LLM and the $\theta_p$-adapted model
  • beam search decoding
  • baseline selection scores $g(\mathbf{x})$ : perplexity, predictive entropy, semantic entropy, Self-eval, P(True)
  • Rouge-L as the evaluation metric $M$ with a relatively large $\gamma = 0.7$ (accepting a wrong answer is more costly than rejecting a correct one); see the sketch after this list
  • both training stages ($\theta_p$ and $\theta_s$) : 10 epochs with AdamW, batch size 8, learning rate 0.01, cosine learning-rate scheduling
  • for ASPIRE,
    • beam search for $\mathcal{A}$
    • $l = 50$
    • $\hat{\gamma} = 0.9$
    • $k = 10$
    • $k_c = 2$
    • $k_w = 10$
    • $\alpha = 0.25$
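As a concrete example of the correctness check, a sketch assuming the `rouge_score` package (the paper's exact evaluation code may differ):

```python
from rouge_score import rouge_scorer

# Rouge-L is the metric M; an output counts as correct when its best match against the
# reference set exceeds the threshold (gamma = 0.7 at evaluation time, gamma_hat = 0.9
# when labeling sampled answers for self-evaluation training).
scorer = rouge_scorer.RougeScorer(["rougeL"])

def is_correct(prediction, references, gamma=0.7):
    best = max(scorer.score(ref, prediction)["rougeL"].fmeasure for ref in references)
    return best > gamma
```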

6.2 Results

Accuracy

Methods to get selection score

  • After soft prompt tuning, the AUACC of the other methods also improves significantly, since accuracy gets better and perplexity becomes more meaningful
  • ASPIRE with OPT-2.7B significantly outperforms Self-eval and P(True) with OPT-30B
  • For the Self-eval and P(True) methods, although OPT-30B attains better accuracy than the adapted OPT-2.7B, its selective prediction performance is much worse
    → the self-evaluation approach alone is not effective, even for high-capacity LLMs

6.3 Empirical Analyses

The effect of $\alpha$

  • $\alpha = 0.25$ gives the best combination of the normalized likelihood and the learned self-eval score
  • In practice, this value can be chosen based on the performance on the validation data

The choices of $\mathcal{A}$

  • compared beam search and multinomial sampling
  • beam search : used the $k$ highest-scoring beams as the answer list (a sketch follows this list)
  • multinomial sampling : tested temperatures 0.1, 1.0, and 2.0
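A small sketch of the candidate-sampling step $\mathcal{A}(f, \theta_p, \mathbf{x})$ with beam search, assuming a Hugging Face causal LM; the model name and question are placeholders (ASPIRE uses prompt-tuned OPT / GPT-2 models), and multinomial sampling would instead pass `do_sample=True` with a temperature:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

question = "Q: Who wrote Hamlet?\nA:"                  # placeholder query
inputs = tok(question, return_tensors="pt")

# Beam search keeps the k highest-likelihood beams as the candidate answer list.
out = model.generate(
    **inputs,
    num_beams=10,                 # k = 10
    num_return_sequences=10,      # keep all k beams
    max_new_tokens=16,
    early_stopping=True,
)
prompt_len = inputs["input_ids"].shape[1]
candidates = [tok.decode(seq[prompt_len:], skip_special_tokens=True) for seq in out]
```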

Training sample efficiency

  • Fixed the number of training steps to 50K while varying the number of training samples
  • ASPIRE can significantly improve selective prediction performance even with limited number of training samples

7. Conclusion

  • Adaptation with self-evaluation to improve selective prediction in LLMs
  • Soft prompt tuning
  • Implement via other PEFT approaches and adapt to larger LLMs (Future work)
  • Did not test with larger and stronger LLMs due to computational constraints

8. Comment

I liked that the confidence is obtained through actual computation and learning rather than simply eliciting it with a prompt. However, the tested models are somewhat old, so I wonder whether this also works with recent sLLMs.
