📝 Week 5: NLP Decoding Strategy

oceann · September 5, 2024

💻 Naver Boostcamp AI Tech 7th Cohort NLP


The thumbnail shows sample responses generated to reflect the character of each decoding strategy.
Beam Search tends to combine frequently used words into short sentences, while Top-k Sampling is strongly random, so unexpected answers sometimes appear. Top-p Sampling picks more varied words while still sounding natural, which allows more flexible responses.

💡
Weeks 4 and 5 ran back to back, and since I got through a lot of lectures last week, this week was relatively relaxed. Relaxed enough that I'm even posting this on a Thursday! I'll spend the remaining days studying MLOps.
Now that we've covered NLP theory, the project is about to start, and I'm excited. Until now I've run projects looking only at the results, but this time I want to focus on doing the process "well".

For the Week 5 summary I considered writing about BERT, but since I also covered it in a paper study group, I plan to post it as a paper review instead.
I'll add the link once it's up.

What is NLP Decoding?

If a language model's final output is a sequence, just like its input, it needs a rule for deciding which token to emit at each step.
The trained model defines a probability distribution over tokens. So, given an input sentence $X$, the ideal choice is the output $y$ with the highest probability under that distribution.

$$\underset{y}{\mathrm{argmax}}\ \log P_{LM}(y \mid X) = \underset{y_1, y_2, \ldots, y_t}{\mathrm{argmax}}\ \log P_{LM}(y_1, y_2, \ldots, y_t \mid X)$$

Exhaustively searching every combination of tokens across all timesteps is clearly infeasible: with a vocabulary of size $|V|$ there are $|V|^t$ candidate sequences of length $t$. So the expression above is rewritten with the chain rule.

$$
\begin{aligned}
P_{LM}(y \mid X) &= P_{LM}(y_1, y_2, \ldots, y_t \mid X) \\
&= P_{LM}(y_1 \mid X) \times P_{LM}(y_2 \mid y_1, X) \times P_{LM}(y_3 \mid y_2, y_1, X) \times \cdots \times P_{LM}(y_t \mid y_1, \ldots, y_{t-1}, X) \\
&= \prod_{i=1}^{t} P_{LM}(y_i \mid y_1, \ldots, y_{i-1}, X)
\end{aligned}
$$

The expression above describes an auto-regressive process: each step's prediction is conditioned on the outputs of the previous steps.
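To make the factorization concrete, here is a minimal sketch that scores a sequence by summing per-step conditional log-probabilities. The toy distribution and the names `next_token_probs` and `sequence_log_prob` are made up for illustration, and the input $X$ is left implicit.

```python
import math

# Hypothetical toy model: P(next token | previous tokens), with X left implicit.
def next_token_probs(prefix):
    if not prefix:
        return {"I": 0.6, "You": 0.4}
    if prefix[-1] == "I":
        return {"am": 0.7, "was": 0.3}
    return {"are": 0.9, "were": 0.1}

def sequence_log_prob(tokens):
    """Sum of log P(y_i | y_1..y_{i-1}), i.e. the log of the chain-rule product."""
    total = 0.0
    for i, token in enumerate(tokens):
        total += math.log(next_token_probs(tokens[:i])[token])
    return total

log_p = sequence_log_prob(["I", "am"])
print(math.exp(log_p))  # approximately 0.6 * 0.7 = 0.42
```

Each factor only needs the tokens generated so far, which is what lets decoding proceed one token at a time.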

Greedy Decoding

The first method is Greedy Decoding: at every timestep, select the token with the highest probability. However, this only approximates the argmax over whole sequences. If a selected token turns out to be a poor choice, the error carries over into later steps, so the globally optimal sequence is not guaranteed.

Figures: result of applying Greedy Decoding; result of the optimal solution.

Termination condition
Keep generating until the <END> token is selected.
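Greedy Decoding with this termination condition can be sketched as follows. The toy bigram table `TABLE`, the `<START>` marker, and the `max_steps` safety cap are assumptions added for illustration and are not part of the lecture material.

```python
# Hypothetical toy bigram model: the next-token distribution depends only on the last token.
TABLE = {
    "<START>": {"nice": 0.5, "dog": 0.4, "<END>": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
}

def next_token_probs(prefix):
    last = prefix[-1] if prefix else "<START>"
    # Any token without an entry in the table is followed by <END> with certainty.
    return TABLE.get(last, {"<END>": 1.0})

def greedy_decode(max_steps=10):
    tokens = []
    for _ in range(max_steps):
        probs = next_token_probs(tokens)
        token = max(probs, key=probs.get)  # pick the most probable token
        if token == "<END>":
            break
        tokens.append(token)
    return tokens

print(greedy_decode())  # ['nice', 'woman']
```

In this toy table the greedy result "nice woman" has joint probability 0.5 × 0.4 = 0.2, while "dog has" would reach 0.4 × 0.9 = 0.36, which illustrates why the greedy choice is not guaranteed to be optimal.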

Beam Search

To compensate for the weakness of Greedy Decoding, Beam Search keeps k candidate sequences at every timestep. Each candidate is scored as follows.

$$\text{score}(y_1, \ldots, y_t) = \log P_{LM}(y_1, \ldots, y_t \mid X) = \sum_{i=1}^{t} \log P_{LM}(y_i \mid y_1, \ldots, y_{i-1}, X)$$

This takes the earlier product $\prod_{i=1}^{t} P_{LM}(y_i \mid y_1, \ldots, y_{i-1}, X)$ and applies $\log$ to turn it into a sum. Each probability $P$ lies between 0 and 1, so its logarithm lies in $(-\infty, 0]$; the larger the score, that is, the closer to 0, the better the candidate is judged to be. Below is an example with k=2.

1. From the start token, compute scores for the k most probable tokens.
2. Expand each of the k candidates and score its continuations, keeping the top k for each. This produces $k^2$ candidates in total at this step. Compare their scores again and keep only the top k.
3. Repeat step 2. In the figure, the selected candidates all happen to follow one of the branches created in step 1, which is fine.
4. Repeat the process above.

Termination conditions
Each candidate sequence being tracked is called a beam, and different beams can select the <END> token at different timesteps. A sentence completed by selecting <END> is set aside as a finished candidate while Beam Search continues with the rest.
The search stops when a predefined timestep is reached, even if <END> has not been generated.
Alternatively, it stops when the number of completed sentences reaches a predefined count.
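The steps and termination conditions above can be sketched roughly as below. The toy table, `<START>`, and the parameter names `k` and `max_steps` are illustrative assumptions. The sketch ranks all continuations of all beams at once, which yields the same top k as first keeping k per beam, and it uses k as the "predefined count" of finished sentences.

```python
import math

# Hypothetical toy bigram model (same shape as a real model's next-token distribution).
TABLE = {
    "<START>": {"nice": 0.5, "dog": 0.4, "<END>": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
}

def next_token_probs(prefix):
    last = prefix[-1] if prefix else "<START>"
    return TABLE.get(last, {"<END>": 1.0})

def beam_search(k=2, max_steps=10):
    beams = [([], 0.0)]  # each beam is (tokens, summed log-probability)
    finished = []        # hypotheses that have selected <END>
    for _ in range(max_steps):
        candidates = []
        for tokens, score in beams:
            for token, p in next_token_probs(tokens).items():
                candidates.append((tokens + [token], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:k]:  # keep only the top k overall
            if tokens[-1] == "<END>":
                finished.append((tokens[:-1], score))
            else:
                beams.append((tokens, score))
        if not beams or len(finished) >= k:
            break
    pool = finished or beams  # fall back to unfinished beams if nothing completed
    return max(pool, key=lambda c: c[1])

best_tokens, best_score = beam_search(k=2)
print(best_tokens, math.exp(best_score))  # ['dog', 'has'] with probability about 0.36
```

With k=2 the search keeps "dog" alive after the first step and ends up with "dog has", the sequence that greedy selection would miss on this toy table.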

Once Beam Search finishes, the sentence with the highest score among the generated candidates is selected.
However, a candidate's score is the sum of the log-probabilities of all its tokens, so negative values keep accumulating and longer sentences automatically receive lower scores.
In addition, because sentences are built from high-probability tokens, the output is prone to repetition, and since people do not speak by always choosing the most probable word, it does not match how language is used in the real world.
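The length penalty is easy to check numerically. The probabilities below are made up for illustration. Dividing by the number of tokens, shown at the end, is one commonly used remedy; it is not covered in the notes above and is included only as a point of comparison.

```python
import math

def score(token_probs):
    """Summed log-probability, as in the Beam Search score."""
    return sum(math.log(p) for p in token_probs)

short = [0.6] * 3   # 3 tokens, each chosen with probability 0.6
long = [0.9] * 20   # 20 tokens, each chosen with probability 0.9

# The long sentence is more confident at every step but still scores lower overall.
print(score(short) > score(long))  # True

# Per-token average (assumed remedy): the long sentence now ranks higher.
print(score(short) / len(short) < score(long) / len(long))  # True
```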

Sampling

To address these drawbacks, sampling is applied to decoding. Sampling draws the next token at random according to the probability distribution, so even low-probability tokens can be selected.
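As a minimal sketch, drawing from the full distribution can be done with `random.choices` from the Python standard library. The distribution and the helper name `sample_token` are hypothetical.

```python
import random

def sample_token(probs, rng):
    """Draw one token at random, weighted by its probability."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded only to make the run reproducible
probs = {"nice": 0.5, "dog": 0.4, "car": 0.1}

counts = {t: 0 for t in probs}
for _ in range(10_000):
    counts[sample_token(probs, rng)] += 1
print(counts)  # "car" is rare but does get selected
```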

Temperature

The first technique uses a temperature $\tau$, which takes values in $(0, \infty)$; $\tau \to 0$ is only reached as a limit, since the logits are divided by $\tau$. With $z_j$ denoting the logit of token $j$:

$$\frac{\exp(z_j/\tau)}{\sum_{i}\exp(z_i/\tau)}$$

As the expression above shows, the logits are divided by $\tau$ before Softmax is applied. There are two cases.
τ > 1
The output distribution becomes flatter, which raises the probability of uncommon words. As a result, the generated sentences become more diverse.
As $\tau \to \infty$, the distribution becomes uniform.
τ < 1
The differences in the output distribution are amplified. The probability of the word that was already most likely increases, so the generated sentences become more precise.
As $\tau \to 0$, the distribution becomes one-hot, which is equivalent to Greedy Decoding.

The figure below should make these two cases easier to understand.

Source: Shivam Mehta github.io
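The two temperature regimes can also be checked numerically. This is a small sketch using only the standard library; the logits are arbitrary example values, and subtracting the maximum is a standard numerical-stability trick that does not change the result.

```python
import math

def softmax_with_temperature(logits, tau):
    if tau <= 0:
        raise ValueError("tau must be positive")
    scaled = [z / tau for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # arbitrary example logits
for tau in (0.1, 1.0, 100.0):
    probs = softmax_with_temperature(logits, tau)
    print(tau, [round(p, 3) for p in probs])
```

At a small `tau` almost all of the mass sits on the largest logit, and at a large `tau` the three probabilities are nearly equal.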

Top-k Sampling

Only the k tokens with the highest probability are kept, and sampling is performed among them. However, this has side effects: when the distribution is sharply peaked, overly rare words can still end up among the k candidates and be output, and when the distribution is flat, tokens with nearly the same probability as the kept ones have to be excluded.
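A rough sketch of Top-k filtering followed by sampling is shown below. The two example distributions, `peaked` and `flat`, are invented to show both side effects with a fixed k=3.

```python
import random

def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize them to sum to 1."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {t: p / total for t, p in top}

def top_k_sample(probs, k, rng):
    filtered = top_k_filter(probs, k)
    tokens = list(filtered)
    return rng.choices(tokens, weights=[filtered[t] for t in tokens], k=1)[0]

peaked = {"a": 0.90, "b": 0.05, "c": 0.03, "d": 0.02}
flat = {"a": 0.26, "b": 0.25, "c": 0.25, "d": 0.24}

print(sorted(top_k_filter(peaked, 3)))  # rare tokens "b" and "c" are still candidates
print(sorted(top_k_filter(flat, 3)))    # "d" is dropped despite a similar probability
print(top_k_sample(flat, 3, random.Random(0)))
```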

Top-p Sampling

Top-p Sampling, also called Nucleus Sampling, was devised to mitigate the side effects of Top-k Sampling. Tokens are selected in descending order of probability until their cumulative probability reaches p, and sampling is performed among them.
Thus, for a sharply peaked distribution only a few tokens become candidates, while for a flat distribution many tokens become candidates.
With p set appropriately, this also serves to exclude tokens with very low probability.
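The adaptive candidate set can be sketched as follows. The distributions and the threshold p=0.85 are illustrative values chosen for this example.

```python
import random

def top_p_filter(probs, p):
    """Keep the most probable tokens until their cumulative probability reaches p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {t: prob / total for t, prob in kept}  # renormalize the kept tokens

def top_p_sample(probs, p, rng):
    filtered = top_p_filter(probs, p)
    tokens = list(filtered)
    return rng.choices(tokens, weights=[filtered[t] for t in tokens], k=1)[0]

peaked = {"a": 0.90, "b": 0.05, "c": 0.03, "d": 0.02}
flat = {"a": 0.26, "b": 0.25, "c": 0.25, "d": 0.24}

print(len(top_p_filter(peaked, 0.85)))  # 1 candidate for the peaked distribution
print(len(top_p_filter(flat, 0.85)))    # 4 candidates for the flat distribution
print(top_p_sample(flat, 0.85, random.Random(0)))
```

Unlike a fixed k, the number of candidates follows the shape of the distribution.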
