ShieldHead: Decoding-time Safeguard for Large Language Model

Yuri·2025년 10월 9일

논문 리뷰

목록 보기
22/23

(ACL 2025, Accept)

Introduction

  • 기존 LLM 기반 가드레일(e.g. LlamaGuard)은 입출력의 위험을 식별하는 데 유망하지만, 가드레일 또한 하나의 LLM인 만큼 추가 추론 단계가 필요하여 계산 비용이 크게 증가하고 효율성이 떨어짐
  • 한편 저자의 관찰에 따르면 LLM의 Hidden State는 유해한 콘텐츠를 식별하는 데 필요한 충분한 Context를 인코딩하고 있음
  • 이에 백본 LLM last hidden state에 경량 classification head를 훈련시키는 ShieldHead라는 새로운 decoding-time safeguard 프레임워크 제안
  • ShieldHead는 Next-token-prediction LM Head와 병렬로 작동하는 보조 Branch 역할을 하며, 과거 텍스트 시퀀스에서 잠재적 위험을 감지

Methnology

  1. ShieldHead for Safety Classification
    • 입력 문장 jj의 토큰 ii에 대해 백본 LLM H\text{H}가 출력하는 hidden state를 h(i,j)Rd\text{h}^{(i,j)} \in \mathbb{R}^d라고 할 때, ShieldHead는 MultiLayer Perceptron (MLP)으로 구성된 Classifier임
    • ShieldHead는 다음과 같이 토큰 수준 안전 확률 ptoken(i,j)RC\mathbf{p}_{\text{token}}^{(i,j)} \in \mathbb{R}^C을 예측:
      ptoken(i,j)=Softmax(ShieldHead(h(i,j)))\mathbf{p}_{\text{token}}^{(i,j)} = \text{Softmax}(\text{ShieldHead}(\mathbf{h}^{(i,j)}))
    • 여기서 CC는 분류될 카테고리 수(본 연구에서는 안전/위험의 이진 분류를 위해 C=2C=2), ptoken(i,j,c)R1\mathbf{p}_{\text{token}}^{(i,j,c)} \in \mathbb{R}^1h(i,j)\mathbf{h}^{(i,j)}가 카테고리 cc로 분류될 확률
    • 이 방식은 백본 모델의 decoding 과정에서 multi-task head를 추가하여 다음 토큰이 decoding됨과 동시에 과거 시퀀스의 위험 분류 결과를 예측하여 실시간 필터링이 가능하도록 함
  2. Label Disambiguation
    • ShieldHead 훈련에는 토큰 수준 안전 데이터가 필수적이지만 이는 높은 라벨링 비용을 요구하므로, 각 클래스에 대한 Prototype을 포함하는 Label Disambiguation 모듈을 도입
      1. 문장별로 문장 내 모든 토큰에 문장 수준 레이블을 Y(i,j)tY^t_{(i,j)}에 초기 레이블로 할당 (e.g. 문장이 unsafe이면 모든 토큰도 unsafe)

      2. 각 클래스에 대한 프로토타입 P=[P1,...,PC]P = [P_1, ..., P_C]를 0 벡터로 초기화

      3. ShieldHead가 각 토큰의 last hidden state h(i,j)h_{(i,j)}를 받아 해당 토큰의 클래스 확률 예측하고, 각 클래스 cc에 대해 가장 높은 예측 확률을 보인 상위 KK개 토큰들의 hidden state TctT^t_c 식별

      4. TctT^t_c를 사용하여 moving-average를 통해 해당 클래스의 프로토타입 PctP^t_c 업데이트

        Pct+1=Normalize(γPct+(1γ)h(i,j)),for h(i,j)TctP^{t+1}_c = \text{Normalize}(\gamma \cdot P^t_c + (1 - \gamma) \cdot h(i,j)), \quad \text{for } h(i,j) \in T^t_c

        여기서 γ\gamma는 업데이트 속도를 조절하는 계수로, 훈련 진행에 따라 γ=0.99\gamma = 0.99에서 γ=0.95\gamma = 0.95로 점차 감소

        이 단계를 통해 각 클래스의 프로토타입 PcP_c는 ShieldHead가 그 클래스에 속할 것이라고 가장 강하게 믿는 토큰들의 평균적인 특징 벡터로 점진적으로 수렴해 나감

      5. PPh(i,j)h_{(i,j)} 사이의 proximity를 기반으로, ii번째 토큰이 각 클래스에 속할 확률을 나타내는 프로토타입 레이블 점수 s(i,j)s_{(i,j)} 계산

        s(i,j)=Softmax(Ph(i,j))s(i,j) = \text{Softmax}(P \cdot h(i,j))
      6. 기존의 토큰 수준 소프트 레이블 Y(i,j)tY^t_{(i,j)}s(i,j)s_{(i,j)}를 사용하여 다시 moving-average로 업데이트

        Y(i,j)t+1=σY(i,j)t+(1σ)s(i,j)Y^{t+1}_{(i,j)} = \sigma \cdot Y^t_{(i,j)} + (1 - \sigma) \cdot s(i,j)

        여기서 σ\sigma도 이동 평균 업데이트 속도를 조절하는 계수로, 훈련 진행에 따라 σ=0.98\sigma = 0.98에서 σ=0.5\sigma = 0.5로 감소

  3. Loss Function
    • 토큰 수준 학습에서는 Cross-entropy와 유사한 손실 함수 사용
      Ltoken=j=1Ni=1Mc=1CY(i,j,c)tlog(ptoken(i,j,c))\mathcal{L}_{\text{token}} = - \sum_{j=1}^N \sum_{i=1}^M \sum_{c=1}^C \mathbf{Y}^t_{(i,j,c)} \cdot \log(\mathbf{p}_{\text{token}}^{(i,j,c)})
      여기서 Y(i,j,c)t\mathbf{Y}^t_{(i,j,c)}는 시간 tt에 업데이트된 Soft Label의 cc-번째 카테고리이며, ptoken(i,j,c)\mathbf{p}_{\text{token}}^{(i,j,c)}jj-번째 문장의 ii-번째 토큰에 대한 예측 확률의 cc-번째 카테고리
    • 문장 수준 학습에서는 Prompt와 Response의 마지막 토큰 예측 결과를 사용하여 각각 Lprompt\mathcal{L}_{\text{prompt}}Lres\mathcal{L}_{\text{res}} Cross-entropy 손실 계산
    • 전체 손실 함수는 다음과 같음
      L=Lprompt+Lres+λLtoken\mathcal{L} = \mathcal{L}_{\text{prompt}} + \mathcal{L}_{\text{res}} + \lambda \mathcal{L}_{\text{token}}
      여기서 λ\lambda는 문장 수준 손실과 토큰 수준 손실 간의 상대적 기여도를 조절하는 Hyperparameter임

image.png

Experimental Result

image.pngimage.png

0개의 댓글