SAFEINFER: Context Adaptive Decoding Time Safety Alignment for Large Language Models

Yuri·2025년 10월 9일

논문 리뷰

목록 보기
15/23

(AAAI 2025, Accept)

Introduction

  • LLM의 안전 정렬이 불균형할 경우 안전하지 않은 콘텐츠를 생성할 가능성이 높음
  • 특히 모델 내부 파라미터를 변경하는 모델 편집 방법이 제안되면서 안전성을 더욱 해칠 수 있음
  • 문맥 적응형 디코딩-타임(decoding-time) 안전성 정렬 방법 SafeInfer 제안

Methonology

SafeInfer는 두 단계로 구성됨

  1. Safety Amplification (SA) 단계
    • Activation patching을 통해 LLM 내에서 영향력 있는 어텐션 헤드 세트 AA를 식별
      • Activation patching: LLM에서 안전한 QA 및 유해한 QA를 각각 실행하고, 각 어텐션 헤드의 activation을 서로 바꿔가면서 출력의 변화를 추적하여 해당 어텐션 헤드의 영향력 평가
    • 안전한 데모 데이터셋 DsfD_{sf}에서 {(q1, a1), (q2, a2), ..., (qn, an), qn+1} 형태의 프롬프트 세트 PP를 구성. 여기서 qq는 안전하지 않은 질문이고 aa는 안전한 답변
    • 각 어텐션 헤드 attnljattn_{lj} (ll은 레이어, jj는 위치)에 대해 프롬프트 세트 PP의 representations 평균을 계산하여 safety conditioned activations attnljattn'_{lj}를 구함
      attnlj=1PpPattnlj(p)attn^\prime_{lj} = \frac{1}{|P|} \sum_{p \in P} attn_{lj}(p)
    • attnljAattn_{lj} \in A에 대해 계산된 attnljattn^\prime_{lj}를 합산하여 단일 벡터인 안전 증폭 벡터 SVSV를 생성
    • SVSV를 대상 모델 MtM_t의 특정 레이어 ll의 은닉 상태 hlh_l에 통합하여 업데이트된 은닉 상태 hlh^\prime_l 및 업데이트된 은닉 상태를 가진 모델 MtM^\prime_t을 구함. 여기서 γ\gamma는 하이퍼파라미터
      hl=hl+γSVh^\prime_l = h_l + \gamma \cdot SV
  2. Safety-Guided Decoding Strategy (sGDS) 단계
    • 유해한 질문-답변 쌍으로 구성된 데이터셋 DusfD_{usf}를 사용하여 동일한 LLM을 fine-tuning하고 유해 모델 MusfM_{usf} 구성
    • MtM^\prime_t의 출력 분포를 보존하면서 **MusfM_{usf}의 유해한 경향을 완화하기 위해 output probabilities 수정
    • Union 연산자 사용하여 MtM^\prime_t MusfM_{usf}의 output distribution을 통합하는 combined distribution CC를 구함
    • Union 연산자는 두 분포 중 하나라도 특정 토큰 xx에 높은 확률이라면 결과 분포도 해당 토큰에 높은 확률을 반영하도록 비선형 결합. 여기서 I(x)I(x)는 인디케이터 함수
      D[I1]KL(CMt)+D[I2]KL(CMusf)whereI1(x)=[Mt(x)>Musf(x)]I2(x)=1I1(x)D_{[I_1]KL}(C || M^\prime_t) + D_{[I_2]KL}(C || M_{usf}) \\ where \quad I_1(x) = [M^\prime_t(x) > M_{usf}(x)] \\ I_2(x) = 1 - I_1(x)
    • KL-divergence로 C(x)C(x)를 구함. 여기서 σ\sigma는 standard softmax 함수
      C(x)=σ(max(logMt(x),logMusf(x)))C(x) = \sigma(\max(\log M^\prime_t(x), \log M_{usf}(x)))
    • MtM^\prime_t의 유해성을 줄이기 위해, MtM^\prime_t의 분포에서 특정 토큰들의 영향을 제한함으로써 안전한 출력 분포 MtsM^s_t를 얻음. 여기서 λ\lambda는 하이퍼파라미터
      Mts=Mtλσ(max(logMt,logMusf))=MtλCM^s_t = M^\prime_t - \lambda \cdot \sigma(\max(\log M^\prime_t, \log M_{usf})) = M^\prime_t - \lambda \cdot C

image.png

Experimental Result

image.png image.png

0개의 댓글