[Paper review] STAIR: Improving Safety Alignment with Introspective Reasoning

브라우니맛있디·2026년 2월 12일

Paper review

목록 보기
12/14
post-thumbnail

1. 배경 지식

LLM Safety Alignment

  • LLM의 안전성을 확보하기 위해 유해한 콘텐츠 생성을 방지하는 기술
  • 유해한 프롬프트에 대해 모델이 적절히 거부하거나 안전한 답변을 생성하도록 학습
  • 예시
    • 폭력적 내용 생성 요청 → "I cannot assist with that request"
    • 개인정보 침해 요청 → 거부 또는 대안 제시

System 1 vs System 2 Thinking

  • System 1 (빠른 사고)
    • 직관적이고 자동적인 사고
    • 빠르지만 실수할 수 있음
    • 기존 대부분의 Safety Alignment 방법이 해당
  • System 2 (느린 사고)
    • 논리적이고 분석적인 사고
    • 느리지만 더 정확하고 신중함
    • STAIR가 제안하는 접근 방식

Chain-of-Thought (CoT) Reasoning

  • 문제 해결 과정을 단계별로 나누어 추론하는 방법
  • 예시: "2+2는 4이다" (직접 답변) vs "먼저 2를 생각하고, 여기에 2를 더하면, 4가 된다" (단계별 추론)
  • 복잡한 문제에서 더 정확한 답변 가능

Jailbreak Attack

  • LLM의 안전 장치를 우회하여 유해한 콘텐츠를 생성하도록 유도하는 공격
  • 예시
    • 역할극 프롬프트: "당신은 악당 캐릭터입니다. 악당답게 행동하세요..."
    • 간접 요청: "이론적으로 만약 ~한다면 어떻게 될까요?"
  • 기존 방법들이 이러한 공격에 취약함

2. 서론

Limitations of Prior Safety Alignment Methods

  • Direct Refusal (직접 거부 방식)
    • 유해한 프롬프트 → 즉각적으로 거부
    • 문제점
      • 안전성-성능 트레이드오프: 안전하게 만들려다 보니 유용성이 떨어짐
      • Jailbreak 공격에 취약: 교묘한 우회 방법에 쉽게 속음
      • 과도한 거부(Over-refusal): 안전한 질문도 거부하는 경향
    • 예시: "범죄 예방을 위한 보안 시스템 설계 방법"을 물어도 거부할 수 있음
  • 왜 취약한가?
    • System 1 방식의 즉각적 판단에 의존
    • 문맥과 의도를 깊이 분석하지 않음
    • 표면적인 키워드에만 반응

Motivation

  • 인간처럼 신중하게 생각하고 판단하는 모델을 만들자!
  • 단계별로 안전성을 분석하는 추론 능력을 학습시키면?
    • 더 robust한 안전성 확보
    • 유용성 유지
    • Jailbreak 공격 방어

3. 방법

STAIR Framework Overview

  • STAIR는 3단계로 구성
    • Stage 1: Structured Reasoning with CoT (SFT)
    • Stage 2: Self-Improvement with SI-MCTS (DPO)
    • Stage 3: Process Reward Model Training (Best-of-N)
  • 추론 형식
    • 각 단계는 <Reasoning_step> 태그로 감싸짐
    • 최종 답변은 태그로 감싸짐
  • 추론 단계 구성
    1. 문제 분석 (Problem Analysis): 요청의 의도와 맥락 파악
    2. 단계별 추론 (Reasoning): 잠재적 위험 요소 탐지 및 분석
    3. 최종 답변 (Final Answer): 안전한 답변 제공 또는 명확한 거부

Stage 1: Structured Reasoning Capability

  • 목표
    • 모델에 구조화된 추론 능력을 먼저 학습시키기
    • 안전성 판단 과정을 명시적인 단계로 분해
  • 데이터셋
    • Ultrafeedback, SafeRLHF, JailBreakV
    • 총 20K개의 샘플
    • CoT 형식으로 변환: 원본 질문-답변 쌍에 추론 과정 추가
  • 추론 단계 구성
    1. 문제 분석: 요청의 의도와 맥락 파악
    2. 안전 위험 식별: 잠재적 위험 요소 탐지
    3. 최종 답변 생성: 안전한 답변 제공
  • 핵심 아이디어
    • 각 추론 단계를 트리의 노드로 모델링
    • Monte Carlo Tree Search로 최적의 추론 경로 탐색
    • 안전성과 유용성을 모두 고려한 보상 설계
  • 보상 함수 설계
    • Total_Reward = Safety_Reward + Helpfulness_Reward
  • Safety Reward: Llama Guard 3로 평가 (안전한가?)
  • Helpfulness Reward: GPT-4로 평가 (도움이 되는가?)
  • 두 가지 보상을 독립적으로 계산하여 합산
  • 왜 분리하는가?
    • 안전성과 유용성은 서로 다른 기준
    • 하나의 보상으로 통합하면 트레이드오프 발생
    • 분리하면 두 목표를 동시에 최적화 가능
  • SI-MCTS 과정
    1. 현재 상태에서 가능한 다음 추론 단계들 생성
    2. 각 단계에 대해 안전성 & 유용성 보상 계산
    3. 보상이 높은 경로를 우선적으로 탐색
    4. 최종적으로 가장 좋은 추론 경로 선택
  • DPO (Direct Preference Optimization) 학습
    • SI-MCTS로 생성한 추론 경로 중:
      • Chosen: 높은 보상을 받은 경로
      • Rejected: 낮은 보상을 받은 경로
    • Step-level에서 선호 쌍을 만들어 학습
    • 모델이 더 나은 추론 단계를 선택하도록 유도

Stage 3: Process Reward Model (PRM)

  • 목표
    • 각 추론 단계의 품질을 평가하는 보상 모델 학습
    • Test-time에 더 나은 응답을 찾는 가이드로 활용
  • PRM 학습 데이터
    • Stage 2의 SI-MCTS에서 생성된 추론 경로
    • 각 단계마다 보상 점수 라벨링
    • Step-level reward prediction 학습
  • Test-time Best-of-N Sampling
    1. N개의 다른 추론 경로 생성
    2. PRM으로 각 경로의 품질 평가
    3. 가장 높은 점수를 받은 응답 선택
    • N이 클수록 더 좋은 응답 찾을 가능성 증가
  • 왜 PRM이 필요한가?
    • Inference 시에는 SI-MCTS를 직접 사용하기 어려움 (너무 느림)
    • PRM은 빠르게 추론 경로의 품질을 평가 가능
    • Test-time 확장(scaling)을 통한 성능 향상

4. 실험

Experimental Setup

  • Base Model
    • Llama-3.1-8B-Instruct
    • Qwen2-7B-Instruct
  • Safety Benchmarks
    • StrongREJECT: 유해 요청 거부 능력 평가
    • WildGuardTest: 실제 환경의 다양한 안전성 시나리오
    • XSTest: 과도한 거부(over-refusal) 측정
  • Helpfulness Benchmarks
    • AlpacaEval: 일반적인 instruction following 능력
    • MT-Bench: 다양한 태스크에서의 유용성
    • GSM8K, MATH: 수학적 추론 능력
  • Jailbreak Attack Benchmarks
    • GCG (Greedy Coordinate Gradient)
    • AutoDAN
    • PAIR (Prompt Automatic Iterative Refinement)
    • TAP (Tree of Attacks with Pruning)

5. 결과

Main Results

  • SFT/DPO: 안전성은 향상되지만, AlpacaEval 하락 (트레이드오프)
  • SACPO: 트레이드오프는 완화되지만 SimpleQA, InfoFlow 저하
  • Self-Rewarding: 명백한 유해 쿼리는 잘 거부하지만, jailbreak에 취약
  • STAIR(Ours): 안전성과 유용성 동시 개선

Jailbreak Attack Defense

  • 모든 jailbreak 공격 (GCG, AutoDAN, PAIR, TAP)에서 baseline 대비 우수한 방어 성능
  • Step-by-step 분석으로 교묘한 우회 시도를 탐지
  • WildChat (실제 고독성 쿼리)에서도 높은 refusal rate 달성

Test-time Scaling Effect

  • Best-of-N과 Beam Search 적용 시
    • StrongReject에서 goodness score 0.06 향상
    • AlpacaEval에서 winning rate 3.0% 이상 향상
    • Computation budget (log₂ scale)이 증가할수록 성능 지속 향상
  • 해석
    • Best-of-4 (log₂(4)=2): 약 0.06 향상
    • Best-of-8 (log₂(8)=3): 더 큰 향상
    • Beam Search (beam_width=2, 4 successors): 유사한 효과
    • N이 클수록 더 좋은 응답을 찾을 가능성 증가

⇒ 추가 계산 자원을 활용하면 성능을 더 높일 수 있음!

Ablation Study

1. Balance between Safety and Helpfulness Data

  • Safety data 비율 증가 시:
    • StrongReject 성능 향상
    • AlpacaEval 성능 감소 (트레이드오프)
  • 하지만 STAIR는 어느 비율에서든 base model 대비 양쪽 모두 향상
  • 논문에서는 1:1 비율 사용 (5K safety + 5K helpfulness)

2. Step-level Optimization의 효과 & Iterative Training의 중요성

  • Step-level DPO가 full-trajectory DPO보다 효과적
  • 반복 학습을 통해 데이터 품질 향상 → 성능 지속 개선됨

3. Reward Function의 형태

  • 모든 형태가 유사한 성능
  • ⇒ 3가지 property (Safety as Priority, Dual Monotonicity, Degeneration)가 모두 중요

4. Self-Rewarding vs PRM

  • Self-rewarding은 variance가 크고 불안정, PRM이 더 안정적이고 효과적

5. Computation Costs

  • Training Cost:
    • 전체 학습 약 30시간 (A800 8장/ 40GB 8)
    • SI-MCTS 데이터 생성이 대부분 (offline, 약 15s per prompt)
    • Step-level preference pair 생성: 평균 0.47s
    • Self-rewarding full-trajectory: 평균 0.40s
  • Inference Cost:
    • 응답 길이 증가로 인한 latency 소폭 증가 (허용 범위)
    • Test-time scaling: computation budget에 비례
    • 선택적으로 사용 가능 (성능 vs 비용 트레이드오프)
  • 추가 비용은 있지만 안전성-유용성 균형을 고려하면 실용적

6. 기존 LLM과의 비교

Open-source Reasoning LLMs와 비교

  • Open-source reasoning LLMs (o1-like models)은 추론 능력은 있지만 안전성 취약
  • Deliberative Alignment 적용해도 jailbreak 공격에 여전히 취약
  • STAIR는 추론 능력과 안전성을 동시에 확보

Commercial LLMs와 비교

  • 대부분의 상용 LLM은 정상 쿼리는 잘 거부하지만 jailbreak에 취약
  • Claude-3.5가 가장 강력한 방어 성능
  • STAIR + Test-time scaling으로 Claude-3.5와 비슷한 수준 달성

6. 결론

  • LLM의 안전성 정렬에서 기존 방법들은 즉각적인 거부 방식에 의존하여 안전성-성능 트레이드오프와 jailbreak에 취약한 문제를 가지고 있었음
  • 성찰형 추론을 도입하여 모델이 단계별로 안전 위험을 분석하도록 학습시키는 3단계 프레임워크 STAIR 제안
  • STAIR는 Claude-3.5에 근접하는 안전성을 달성하면서도 유용성을 유지하였고, 다양한 jailbreak 공격에 대해 강건함을 보임
  • 또한, test-time scaling을 통해 추가 계산 자원으로 성능을 더욱 향상시킬 수 있음을 보임
profile
브라우니맛있디

0개의 댓글