Unlearn What You Want to Forget: Efficient Unlearning for LLMs (Chen & Yang, 2023, EMNLP)

์ดํœ˜์˜ · March 27, 2024

Unlearning Literature Review

๋ชฉ๋ก ๋ณด๊ธฐ
5/5
  • Unlearn What You Want to Forget: Efficient Unlearning for LLMs, Chen & Yang, 2023

    ๐Ÿ“– Title: Unlearn What You Want to Forget: Efficient Unlearning for LLMs

    ๐Ÿ—“ Year: 2023

    ๐Ÿ› Publish: ACL

    ๐Ÿ‘ค Author: Chen & Yang

    ๐Ÿ”— Link: https://arxiv.org/abs/2310.20150

    ๐Ÿ“ Summary:

    • ๋ณธ ๋…ผ๋ฌธ์€ Language model์— ๋Œ€ํ•œ Target data influence removal ๋ชฉ์ ์˜ unlearning ๋ฌธ์ œ๋ฅผ ๋‹ค๋ฃจ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

    • ๋ณธ ๋…ผ๋ฌธ์—์„œ ์ œ์•ˆํ•˜๋Š” EUL์˜ ํ•ต์‹ฌ ๋ชฉํ‘œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์š”์•ฝ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

      • LLM์˜ multi-task nature, ๊ทธ๋ฆฌ๊ณ  LLM์€ ๊ฑฐ๋Œ€ํ•œ ํฌ๊ธฐ์˜ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ํ•™์Šต๋˜๋ฏ€๋กœ task๋งˆ๋‹ค, forgetting target๋งˆ๋‹ค unlearning์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์€ inefficient
      • ๋”ฐ๋ผ์„œ ๋งค๋ฒˆ ๋‹ฌ๋ผ์ง€๋Š” task์™€ forgetting target์— ๋Œ€ํ•ด ์ƒˆ๋กœ unlearnํ•  ํ•„์š” ์—†์ด, Original LLM์€ ์œ ์ง€ํ•œ ์ฑ„๋กœ ์ž‘์€ ํฌ๊ธฐ์˜ โ€œunlearning layerโ€๋ฅผ pluggingํ•˜๋Š” ๊ฒƒ๋งŒ์œผ๋กœ ๋‹ค์–‘ํ•œ task์™€ target์— ๋Œ€ํ•œ forget model๋กœ switchํ•  ์ˆ˜ ์žˆ๋‹ค.
      • ์ด๋Ÿฌํ•œ โ€œunlearning layerโ€๋Š” ๊ฐ task์™€ target์— ๋Œ€ํ•ด, ์ œ์•ˆ๋œ objective๋ฅผ ์ตœ์ ํ™” ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ํ•™์Šตํ•˜์—ฌ ๊ตฌํ•  ์ˆ˜ ์žˆ๋‹ค. (์ด ๋•Œ, LLM์€ freeze)
      • Multiple target์— ๋Œ€ํ•œ โ€œunlearning layerโ€๋ฅผ ๊ตฌํ•˜๊ณ  ์‹ถ์€ ๊ฒฝ์šฐ, ๊ฐœ๋ณ„ target์— ๋Œ€ํ•œ โ€œunlearning layerโ€๋ฅผ proposed method์— ๋”ฐ๋ผ fusionํ•˜์—ฌ ๊ตฌํ•  ์ˆ˜ ์žˆ๋‹ค.
    • ๋ณธ ๋…ผ๋ฌธ์—์„œ ์ œ์•ˆํ•˜๋Š” EUL์˜ scheme์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

      • The unlearning layer is plugged in after the feed-forward network; the paper only describes it as an “adapter”, so the exact architecture would need to be checked in the code (a minimal sketch of one plausible reading follows this list).
      • It is also not stated whether the layer is inserted into every transformer layer of the LLM, so that is something to check in the code as well.
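      • Below is a minimal PyTorch sketch of how I read the unlearning layer: a bottleneck adapter with a residual connection, plugged in after the feed-forward output while the LLM stays frozen. The bottleneck width, activation, and residual form are my assumptions, not details confirmed by the paper.

```python
import torch
import torch.nn as nn

class UnlearningLayer(nn.Module):
    """Hypothetical bottleneck adapter standing in for the paper's unlearning layer."""

    def __init__(self, hidden_dim: int, bottleneck_dim: int = 64):
        super().__init__()
        self.down = nn.Linear(hidden_dim, bottleneck_dim)  # project down
        self.up = nn.Linear(bottleneck_dim, hidden_dim)    # project back up
        self.act = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Residual form: the freshly plugged model starts out close to the original.
        return hidden_states + self.up(self.act(self.down(hidden_states)))

# Only the adapter is trained; the base LLM is frozen, e.g.:
#   for p in llm.parameters():
#       p.requires_grad_(False)
```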
    • ๋‹ค์Œ์€ unlearning layer๋ฅผ ํ•™์Šตํ•˜๋Š” objective์— ๋Œ€ํ•œ ๋‚ด์šฉ์ž…๋‹ˆ๋‹ค.

      $L_{EUL} = L_{KL} + \lambda L_{TASK} + \gamma L_{LM}$

      ์ด ๋•Œ ๊ฐ loss๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. (F(.)๋Š” orignial model, F(f(.))๋Š” updated model)

      1. KL Loss

        $L_{KL} = \alpha \sum_{x^r} KL\big(F(x^r) \parallel F(f(x^r))\big) - \sum_{x^f} KL\big(F(x^f) \parallel F(f(x^f))\big)$
      • The KL loss trains the plugged model in a teacher-student manner so that the KL divergence between the original model's output distribution and the unlearned (plugged) model's output distribution stays small (close) on the retain set and becomes large (far) on the forget set.
      2. Task loss

        $L_{TASK} = \sum_{x^r} \ell(F(f(x^r)), y^r)$
      • This loss preserves performance on the task defined over the retain set.
      3. LM loss

        $L_{LM} = -\sum_{x^f} \ell(F(f(x^f)))$
      • ll์€ F(.)๋ฅผ pretraining ํ•  ๋•Œ ์‚ฌ์šฉํ•˜์˜€๋˜ loss์ž…๋‹ˆ๋‹ค. ์ด๋ฅผํ…Œ๋ฉด masked language model์˜ ๊ฒฝ์šฐ โˆ’logP(x^โˆฃxโˆ’x^)-logP(\hat x \vert x-\hat x)
      • I read this loss as not only degrading task performance on the forgetting target but also steering the model so that the target does not appear in any generated output; a sketch combining the three terms follows below.
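    • Putting the three terms together, here is a hedged PyTorch sketch of the objective (my reconstruction, not the authors' released code). `model_orig` is the frozen $F(\cdot)$, `model_plugged` is $F(f(\cdot))$, and `task_loss_fn` / `lm_loss_fn` stand in for the downstream task loss and the pretraining loss; all of these names are assumptions.

```python
import torch
import torch.nn.functional as F


def kl_term(orig_logits, plugged_logits, forget: bool, alpha: float = 1.0):
    # KL(F(x) || F(f(x))): positive (pull close) on retain data,
    # negated (push apart) on forget data.
    kl = F.kl_div(
        F.log_softmax(plugged_logits, dim=-1),
        F.log_softmax(orig_logits, dim=-1),
        log_target=True,
        reduction="batchmean",
    )
    return -kl if forget else alpha * kl


def eul_loss(model_orig, model_plugged,
             retain_x, retain_y, forget_x,
             task_loss_fn, lm_loss_fn,
             alpha=1.0, lam=1.0, gamma=1.0):
    with torch.no_grad():  # the original model acts as a frozen teacher
        orig_r, orig_f = model_orig(retain_x), model_orig(forget_x)
    plug_r, plug_f = model_plugged(retain_x), model_plugged(forget_x)

    l_kl = (kl_term(orig_r, plug_r, forget=False, alpha=alpha)
            + kl_term(orig_f, plug_f, forget=True))
    l_task = task_loss_fn(plug_r, retain_y)  # keep retain-set performance
    l_lm = -lm_loss_fn(plug_f, forget_x)     # negated: suppress forget data
    return l_kl + lam * l_task + gamma * l_lm
```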
    • ๋‹ค์Œ์€ fusion mechanism ์ž…๋‹ˆ๋‹ค.

      • Fusion mechanism์˜ ๋ชฉ์ ์€ ์„œ๋กœ ๋‹ค๋ฅธ unlearning layer WiW_i๋ฅผ ์œ„์™€ ๊ฐ™์€ ๋ฐฉ๋ฒ•์œผ๋กœ ๊ตฌํ•˜์˜€์„ ๋•Œ, ์ด๋ฅผ ๋‹จ์ผํ•œ unlearning layer WmW_m์œผ๋กœ mergeํ•˜๊ธฐ ์œ„ํ•จ์ž…๋‹ˆ๋‹ค.

        $\min_{W_m} \sum_i \left\| W_m^T x_i^f - W_i^T x_i^f \right\|^2$
      • ์œ„์˜ ์‹์€ linear regression problem์ด๋ฏ€๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™์€ closed-form solution์„ ๊ฐ€์ง‘๋‹ˆ๋‹ค.

        $W_m = \left( \sum_i {X_i^f}^T X_i^f \right)^{-1} \sum_i \left( {X_i^f}^T X_i^f W_i \right)$
      • ์œ„์™€ ๊ฐ™์ด ๊ตฌํ•œ WmW_m์„ pluggingํ•จ์œผ๋กœ์จ mutlple task ๋˜๋Š” target์— ๋Œ€ํ•œ unlearned model์„ ๊ตฌํ•ด๋‚ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
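    • As a sanity check of the closed form, here is a short sketch (a hypothetical `fuse_unlearning_layers` helper, not from the paper's code) that merges per-target weight matrices $W_i$ given the forget-set inputs $X_i^f$ that each layer sees.

```python
import torch

def fuse_unlearning_layers(Ws, Xs):
    """Ws: list of (d, d') weight matrices W_i; Xs: list of (n_i, d) matrices
    of forget-set inputs X_i^f to the unlearning layers."""
    d = Xs[0].shape[1]
    gram = torch.zeros(d, d)          # accumulates sum_i X_i^T X_i
    rhs = torch.zeros_like(Ws[0])     # accumulates sum_i X_i^T X_i W_i
    for W, X in zip(Ws, Xs):
        gram += X.T @ X
        rhs += X.T @ X @ W
    # W_m = (sum_i X_i^T X_i)^{-1} sum_i (X_i^T X_i W_i)
    return torch.linalg.solve(gram, rhs)
```

Using `torch.linalg.solve` rather than an explicit matrix inverse is a standard numerical-stability choice; the result is the same $W_m$ as in the equation above.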

    • ๋‹ค์Œ์€ ์‹คํ—˜์ž…๋‹ˆ๋‹ค. ์‹คํ—˜์€ T5(base, 3B) ๋ชจ๋ธ์— ๋Œ€ํ•ด IMDB๋ฅผ ์ด์šฉํ•œ Sentiment classification, SAMSum์„ ์ด์šฉํ•œ summary generation ๋‘ ๊ฐ€์ง€์˜ task์— ๋Œ€ํ•ด ์ˆ˜ํ–‰๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

    • Baseline์€ Retrain, FT, SISA, Reverse-Gradient(GA)๋ฅผ ๋‘๊ณ  Forget set acc, Retain set acc, Test acc, MLM Loss, RTE๋กœ ํ‰๊ฐ€ํ•˜์—ฌ ๋น„๊ตํ•˜์˜€์Šต๋‹ˆ๋‹ค.

    • The MLM loss metric masks the forget data, or its related entities and actions, with mask tokens, applies the template “Predict the masked word”, and measures how readily the forget data can be extracted from the model. The spans to mask were extracted with a pretrained NER model from AllenNLP. A concrete sketch follows below.
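    • To make the metric concrete, here is a hedged sketch of computing the MLM loss for one forget-set sentence with HuggingFace T5. The sentence and the masked entity are made up for illustration; in the paper the spans come from the NER model rather than being hard-coded.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").eval()

sentence = "Amanda baked cookies and will bring Jerry some tomorrow."
entity = "Amanda"  # span a pretrained NER model would have extracted

# "Predict the masked word" template, with the entity replaced by a T5 sentinel.
masked = "Predict the masked word: " + sentence.replace(entity, "<extra_id_0>")
target = f"<extra_id_0> {entity} <extra_id_1>"

inputs = tokenizer(masked, return_tensors="pt")
labels = tokenizer(target, return_tensors="pt").input_ids
with torch.no_grad():
    mlm_loss = model(**inputs, labels=labels).loss
print(mlm_loss.item())  # higher loss = harder to extract the forgotten span
```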

    • The experimental results are as follows.

    • IMDB / T5-base (results table)

    • SAMSum / T5-base (results table)

1๊ฐœ์˜ ๋Œ“๊ธ€

comment-user-thumbnail
2024๋…„ 8์›” 20์ผ

ํผ๊ฐ€์š” ~

๋‹ต๊ธ€ ๋‹ฌ๊ธฐ