SAFEINFER: Context Adaptive Decoding Time Safety Alignment for Large Language Models
(AAAI 2025, Accept)
Introduction
- LLM의 안전 정렬이 불균형할 경우 안전하지 않은 콘텐츠를 생성할 가능성이 높음
- 특히 모델 내부 파라미터를 변경하는 모델 편집 방법이 제안되면서 안전성을 더욱 해칠 수 있음
- 문맥 적응형 디코딩-타임(decoding-time) 안전성 정렬 방법 SafeInfer 제안
Methonology
SafeInfer는 두 단계로 구성됨
- Safety Amplification (SA) 단계
- Activation patching을 통해 LLM 내에서 영향력 있는 어텐션 헤드 세트 A를 식별
- Activation patching: LLM에서 안전한 QA 및 유해한 QA를 각각 실행하고, 각 어텐션 헤드의 activation을 서로 바꿔가면서 출력의 변화를 추적하여 해당 어텐션 헤드의 영향력 평가
- 안전한 데모 데이터셋 Dsf에서
{(q1, a1), (q2, a2), ..., (qn, an), qn+1} 형태의 프롬프트 세트 P를 구성. 여기서 q는 안전하지 않은 질문이고 a는 안전한 답변
- 각 어텐션 헤드 attnlj (l은 레이어, j는 위치)에 대해 프롬프트 세트 P의 representations 평균을 계산하여 safety conditioned activations attnlj′를 구함
attnlj′=∣P∣1p∈P∑attnlj(p)
- attnlj∈A에 대해 계산된 attnlj′를 합산하여 단일 벡터인 안전 증폭 벡터 SV를 생성
- SV를 대상 모델 Mt의 특정 레이어 l의 은닉 상태 hl에 통합하여 업데이트된 은닉 상태 hl′ 및 업데이트된 은닉 상태를 가진 모델 Mt′을 구함. 여기서 γ는 하이퍼파라미터
hl′=hl+γ⋅SV
- Safety-Guided Decoding Strategy (sGDS) 단계
- 유해한 질문-답변 쌍으로 구성된 데이터셋 Dusf를 사용하여 동일한 LLM을 fine-tuning하고 유해 모델 Musf 구성
- Mt′의 출력 분포를 보존하면서 **Musf의 유해한 경향을 완화하기 위해 output probabilities 수정
- Union 연산자 사용하여 Mt′와 Musf의 output distribution을 통합하는 combined distribution C를 구함
- Union 연산자는 두 분포 중 하나라도 특정 토큰 x에 높은 확률이라면 결과 분포도 해당 토큰에 높은 확률을 반영하도록 비선형 결합. 여기서 I(x)는 인디케이터 함수
D[I1]KL(C∣∣Mt′)+D[I2]KL(C∣∣Musf)whereI1(x)=[Mt′(x)>Musf(x)]I2(x)=1−I1(x)
- KL-divergence로 C(x)를 구함. 여기서 σ는 standard softmax 함수
C(x)=σ(max(logMt′(x),logMusf(x)))
- Mt′의 유해성을 줄이기 위해, Mt′의 분포에서 특정 토큰들의 영향을 제한함으로써 안전한 출력 분포 Mts를 얻음. 여기서 λ는 하이퍼파라미터
Mts=Mt′−λ⋅σ(max(logMt′,logMusf))=Mt′−λ⋅C

Experimental Result
