COCONUT - CoT๋ณด๋ค ํจ์จ๊ณผ ์ฑ๋ฅ์ด ์ข์ ํ์ต ๋ฐฉ๋ฒ(Chain of Continuous Thoughts)
์ ์์ฝ
CoT๋ฅผ ํฌํจํ ํ์ฌ๊น์ง์ ํ์ต ๋ฐฉ๋ฒ์ LLM์ ๋ค์ ๋จ๊ณ ์์ธก์ ์์ฐ์ด ํ ํฐ์ ํ์ฉํจ
๊ทธ๋ฌ๋ ๊ผญ ๊ทธ๋ ๊ฒ ํ ํ์๋ ์์. LLM์๊ฒ ๋ค์ ๋จ๊ณ ํ ํฐ์ผ๋ก ์์ฐ์ด ํ ํฐ ๋์ , ๋ง์ง๋ง ์๋์ธต์ ๊ฒฐ๊ณผ๋ฅผ ๋ฃ์ ์ ์๋ ์์ ๋ฅผ ์ฃผ๋ฉด ๋์ฑ ํจ์จ์ ์ด๊ณ , ์ฑ๋ฅ์ด ์ข์์ง
๊ฐ๋จ ๊ฒฐ๊ณผ:
ProntoQA:
CoT: 98.8 % Acc., 92.5 tokens
COCONUT: 99.8 % Acc., 9.0 tokens
Introduction
LLM์ด ๋
ผ์ฆํ ๋, ์์ฐ์ด ํ ํฐ์ ์ฌ์ฉํ๋ฉด
- ๋
ผ์ฆ๋ง๋ค ํ์ํ ๊ณ์ฐ๋์ด ๋ค๋ฅด๋ค๋ ์ฌ์ค์ ๋ฐ์ํ์ง ๋ชปํจ
- ๋
ผ์ฆ์ฌ์ฌ(reasoning chain)์ ๋๋ค์ ํ ํฐ์ ๋ฌธ์ฅ์ ๋งค๋๋ฝ๊ฒ ๋ง๋๋ ์ญํ ์ ํ ๋ฟ, ๋
ผ์ฆ์ ๋ฏธ์น๋ ์ํฅ์ ๋ฏธ๋ฏธํจ
์ด์ ์ฐ๊ตฌ ๊ฒฐ๊ณผ๋ค์ Related Work ์ฐธ๊ณ
- Related Work
- Chain of Thought (and its variants):
- ๊ฒฐ๊ณผ๋ฅผ ๋ด๊ธฐ ์ ์ ๋
ผ์ฆํ๋ ๋ฐฉ๋ฒ์ ํต์นญ. ํ์ต, prompting, ๊ฐํํ์ต์ ๋ชจ๋ ํฌํจํจ. ํจ๊ณผ์ ์ด์ง๋ง ์๊ธฐํ๊ท์ ์ฑ์ง์ด ๋ณต์กํ task์์๋ ์ฝ์ ์ผ๋ก ์์ฉ
- ์ฝ์ ์ ๊ทน๋ณตํ๊ธฐ ์ํด tree search๋ฅผ ์ถ๊ฐํ๊ฑฐ๋, search dynamics๋ฅผ ํ์ต์ํค๋ ๋ฐฉ๋ฒ์ด ์ ๊ธฐ๋จ (์ธ์ด ํ ํฐ ์ฌ์ฉ)
- LLM์ ์๋ ๋
ผ์ฆ:
- ๊ธฐ์กด ์ฐ๊ตฌ์์ ์๋ ๋
ผ์ฆ์ ์ค๊ฐ ๋จ๊ณ ํ ํฐ์ ์ง์นญํจ
- Transformer์ ์ค๊ฐ ๋จ๊ณ ํ ํฐ์ ๋ถ์ํด ๋ณด๋ฉด, CoT๋ฅผ ์์ฑํ๋๋ผ๋ ์ค๊ฐ ๋จ๊ณ ํ ํฐ์ ์์ฑ๋ CoT์๋ ๋ค๋ฅธ ๋
ผ์ฆ ๊ณผ์ ์ ๊ฑฐ์น๋ค๋ ์ ์ด ๋ฐ๊ฒฌ๋จ
- (unfaithfulness of CoT reasoning) โ ์๋ ๋
ผ์ฆ์ ์ ๋๋ก ํ์ฉํ๊ณ ์์ง ๋ชปํจ
- ์๋ ๋
ผ์ฆ์ ๋ ์ ํ์ฉํ๊ธฐ ์ํ ์ฐ๊ตฌ
- pause (Think before you speak)
- ๋
ผ์ฆ ํ ํฐ์ ์์ฑํ๊ธฐ ์ ์ pauseํ ํฐ์ ์์ฑํ๋ฉด์ ์๊ฐํ๊ฒ ํ๊ณ , ๋ง์ง๋ง pauseํ ํฐ ์ดํ์ ๊ฒฐ๊ณผ๋ง ์ฌ์ฉ
- (pretrain, finetune ํ์)
- implicit-CoT (From Explicit CoT to Implicit CoT)
- ํ์ต ๋จ๊ณ๋ณ๋ก ๋
ผ์ฆ ํ ํฐ์ ์ค์ โ COCONUT์์๋ ์ฑํ
COCONUT์ LLM์๊ฒ ์ํ ๋ ์์ฐ์ด ํ ํฐ์ผ๋ก ๋ณํํ ์ ์๊ฒ ํ๋ ๋ฐฉ๋ฒ
- ์ธ์ด ๋ชจ๋์ ์๋ ๋ชจ๋ ์ฌ์ด์ ๋ณํ์ด ๊ฐ๋ฅ
- ๋ณํ์
<bot>
, <eot>
์ ํน์ ํ ํฐ์ ์ฌ์ฉ
- ์๋ ๋ชจ๋์์๋ ์๋์ธต(Latent)์ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ ํ ํฐ์ผ๋ก ์ฌ์ฉ
- ์ธ์ด ๋ชจ๋์์๋ ์ผ๋ฐ์ ์ธ LLM์ผ๋ก ์๋
- ๋ค๋จ๊ณ ํ์ต๋ฒ ์ ์ฉ
- ๋ถ์ ๊ฒฐ๊ณผ, ์๋ ๋ชจ๋์ ํ ํฐ์ ๊ฐ๋ฅํ ๋ค์ ์ํ๋ฅผ ์ค์ฒฉํด์ encodeํจ
- ์ด๋ CoT์์๋ ๋ถ๊ฐ๋ฅ
- ๋
ผ์ฆ ๊ณผ์ ์ BFS์ ๋น์ทํ ๊ตฌ์กฐ๋ก ๋ง๋ฆ
- ์ฅ๊ธฐ ๊ณํ์ด ํ์ํ ์์
์ผ์๋ก, CoT๋ณด๋ค ํจ์จ์ ์ด๋ฉด์๋ ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋
COCONUT: Chain of Continuous Thought
LLM in a nutshell
์ผ๋ฐ์ ์ธ LLM ๊ตฌ์กฐ์์
-
์
๋ ฅ sequence๋ ํ ํฐ๋ณ๋ก ๋ถํด๋์ด embedding function e
๋ฅผ ๊ฑฐ์น๋ค.
E_t: ํ ํฐ embedding sequence = [e(x1), โฆ , e(x_t)]
-
์ดํ ํธ๋์คํฌ๋จธ ๊ธฐ๋ฐ ๋ชจ๋ธ์ ๊ฑฐ์ณ hidden state seqence H_t๊ฐ ๋๋ค
h_t: ๋ง์ง๋ง ํ ํฐ์ hidden state
-
์ต์ข
์ ์ผ๋ก language model head W๋ฅผ ๊ฑฐ์น ๊ฒฐ๊ณผ์ softmax๋ฅผ ์ทจํ๋ฉด ๋ค์ ์์ฐ์ด ํ ํฐ์ ๋ํ ํ๋ฅ ๋ถํฌ๊ฐ ๋๋ค.
COCONUT ๊ตฌ์กฐ์์
<bot>
, <eot>
์ ํน์ ํ ํฐ์ ์ฌ์ฉํ์ฌ ์๋ ๋ชจ๋๋ฅผ ํ๊ธฐํ๋ค.
- ์๋ ๋ชจ๋์์๋ e(xk) ๋์ h{k-1}์ ์ฌ์ฉํ๋ค.
- i๋ฒ์งธ ํ ํฐ์ ์ฒ๋ฆฌํ ๋, ๋ embedding์ ๊ฑฐ์ณ์(e(x_i)) ์๋ ๊ฒฐ๊ณผ h_i๋ฅผ ๋ธ๋ค.
- (i+1)๋ฒ์งธ ํ ํฐ์ ์ฒ๋ฆฌํ ๋, h_i์ W๋ฅผ ์ ์ฉํ์ง ์๊ณ (W h_i ์์)
- softmax๋ฅผ ์ ์ฉํ์ง ์๊ณ (Softmax (W h_i) ์์)
- embedding์ ์ ์ฉํ์ง ์๊ณ (e(Softmax (W h_i)) ์์)
- h_i๋ฅผ ๋ค์ ํธ๋์คํฌ๋จธ ์
๋ ฅ sequence์ ์ผ๋ถ๋ก ํ์ฉํ๋ค
- ์
๋ ฅ sequence์์ ๊ฐ i๋ฒ์งธ(x_i = )์ ์๊ณ , ๊ฐ j๋ฒ์งธ์ ์๋ ๊ฒฝ์ฐ, E_k(i < k < j)๋ ๋ค์๊ณผ ๊ฐ๋ค
- Ek = [e(x_1), โฆe(x_i), h_i, h{i+1}, โฆ, h_{k-1}]
- h_t๋ ์ต์ข
normalization layer๋ฅผ ๊ฑฐ์น ๊ฒฐ๊ณผ์ด๋ฏ๋ก, ํฌ๊ธฐ๊ฐ ํฌ์ง ์๋ค
- ๋ค์ ์ธ์ด ํ ํฐ ํ๋ฅ ๋ถํฌ Softmax (W h_t)๋ฅผ ๊ตฌํ ํ์๋ ์์ผ๋, ๋ถ์ ๋ชฉ์ ์ผ๋ก๋ ์ฌ์ฉ ๊ฐ๋ฅํ๋ค
- ๋ง์ง๋ง ์๋ ํ ํฐ์ ํ์ฌ์ ๋
ผ์ฆ ์ํฉ์ ํํํ๊ณ , ์ด๋ฅผ continous thought๋ผ๊ณ ๋ช
๋ช
ํ๋ค.
ํ์ต ๊ณผ์
CoT ๋ฐ์ดํฐ๋ฅผ ์ง๋ ๋ ์ผ์, ์ง๋ฌธ์ ์
๋ ฅ์ผ๋ก ํ๊ณ ๋
ผ์ฆ ๋จ๊ณ๋ฅผ ๊ฑฐ์ณ ๋ต์ ๋ด๋๋ก ํ์ตํ๋ค.
ํ์ต ๊ณผ์
- stage 0์์๋ CoT ๋ฐ์ดํฐ๋ฅผ ๊ทธ๋๋ก ํ์ตํ๋ค.
- ๊ฐ ๋จ๊ณ(stage)๋ง๋ค ์์์๋ถํฐ CoT ๋
ผ์ฆ ํ ํฐ 1 ๊ฐ๋ฅผ c ๊ฐ์ continuous thought ํ ํฐ์ผ๋ก ๋์ฒดํ๋ค
- ๋จ๊ณ๋ฅผ ์์ํ ๋๋ง๋ค optimizer state๋ฅผ resetํ๋ค
- Negative Log Likelihood loss๋ฅผ ์ฌ์ฉํ๊ณ , ์ง๋ฌธ๊ณผ ์๋ ์๊ฐ์ด ์๋, ์์ฐ์ด ๋
ผ์ฆ ๊ณผ์ ์์๋ง loss๋ฅผ ๊ณ์ฐํ๋ค
- ์๋ ์๊ฐ์ ์ ๊ฑฐ๋ ์์ชฝ ์์ฐ์ด ํ ํฐ์ ๋์ด๋ฆฌ๊ฑฐ๋ ์์ถํ๋ ๋ฐฉํฅ์ด ์๋, ๋ฏธ๋ ๋
ผ์ฆ ๋จ๊ณ๋ฅผ ์ ์์ธกํ๋๋ก ํ์ต๋๋ค
- It is important to note that the objective does not encourage the continuous thought to
compress the removed language thought, but rather to facilitate the prediction of future reasoning.
ํ์ต ๊ณผ์ - ์ธ๋ถ
- continous thought๋ back-propagation์ผ๋ก grad ๊ณ์ฐ ๊ฐ๋ฅ โ ํ์ต ์ฉ์ด
- n ๊ฐ์ ์๋ ์๊ฐ์ ๋ผ์๋ฃ์ ๋จ๊ณ์์, (n+1) forward pass๊ฐ ํ์ํจ
- ๊ฐ ์๋ ํ ํฐ์ ๋ง๋ค๊ธฐ ์ ๊น์ง๋ ๊ฐ์ ๋ชจ๋ฅด๋(GT ์์) ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉํ๋ teacher forcing๋ฅผ ์ฌ์ฉํ ์ ์์
- ์๋ ํ ํฐ์ ๋ง๋๋ ๊ณผ์ ์ ์ง๋ ฌ โ ๋ณ๋ ฌํ๋ ํด๊ฒฐํด์ผ ํ๋ ๊ณผ์
์ถ๋ก ๊ณผ์
์ง๋ฌธ์ ์งํ์ ํ ํฐ์ด ์ค๊ณ , ์ ํด์ง ๊ฐ์์ continous thought ์ดํ์ ํ ํฐ์ด ๋ฑ์ฅ
continuous thought ํ ํฐ์ ๊ฐ์๋ ๋ณ๊ฐ์ model๋ก ์ธ์ ๋๋ผ์ง ์ถ๋ก ํ๋ ๊ฒ๊ณผ ์ ํด์ง ๊ฐ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ ๋ชจ๋ ์ ๋์ด์ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ ์ฑํํจ
Experiments
3 ๊ฐ์ dataset ํ์ฉ / GT์์ ๋น๊ต๋ก ์ ํ๋ ๊ณ์ฐ / ์ ๋ต ๋ด๊ธฐ๊น์ง ํ์ํ๋ ์ถ๊ฐ ํ ํฐ๋ ํ๊ธฐ
Datasets
- GSM8k: ์ํ ๋
ผ์ฆ
- ProntoQA: ๋
ผ์ฆ
- ํธ๋ฆฌ ๊ตฌ์กฐ์ ๋งฅ๋ฝ์ ์์๋ก ์์ฑํ๊ณ , ์์ฐ์ด๋ก ์ฃผ์ด์ง
- ProsQA: ๋
ผ์ฆ
- ProntoQA๊ฐ ๋๋ฌด ์ฌ์ ๋
ผ์ฆ ๋ฅ๋ ฅ์ ํ๊ฐํ๊ธฐ์ ๋ถ์ ์ ํจ
- DAG ๊ตฌ์กฐ์ ๋งฅ๋ฝ์ ์์๋ก ์์ฑํ๊ณ , ์์ฐ์ด๋ก ์ฃผ์ด์ง
- Each problem is structured as a binary question: โIs [Entity] a [Concept A] or [Concept B]?โ
- The graph is constructed such that a path exists from [Entity] to [Concept A] but not to [Concept B].
Experimental Setup
base: pre-trained GPT-2
์ํ ๋
ผ์ฆ: c = 2, 3 ๋จ๊ณ๊น์ง๋ ๋
ผ์ฆ ์ธ์ด ํ ํฐ์ 1 ๊ฐ์ฉ ์์ ๋ค๊ฐ, 4 ๋จ๊ณ์์๋ continuous thought ํ ํฐ ๊ฐ์๋ฅผ ์ ์งํ ์ฑ๋ก ๋
ผ์ฆ ์ธ์ด ํ ํฐ์ ๋ชจ๋ ์์ค๋ค. 3 epoch / stage โ ๊ธด ์ค๋ช
long-tail์ ๊ฐ๊ฑดํด์ง
๋
ผ๋ฆฌ ๋
ผ์ฆ: c =1, ๋ ๋ฐ์ดํฐ์
๋ชจ๋ ๋
ผ์ฆ ์ธ์ด ํ ํฐ์ ์ต๋๊ฐ์ด 6์ด๋ฏ๋ก, ํ์ต ๋จ๊ณ๋ฅผ 6 ๋จ๊ณ๋ก ์ค์ ํ๋ค. 5 epoch / stage
๋ ๋์๊ฐ ๋จ๊ณ๊ฐ ์์ผ๋ฉด 50 epoch๊น์ง ๋ง์ง๋ง ๋จ๊ณ๋ก ํ์ตํ๋ค.
์ถ๋ก ์์๋ continuous thought ํ ํฐ์ ๊ฐ์๋ฅผ ํ์ต ์ ๋ง์ง๋ง ๋จ๊ณ์์ ์ฌ์ฉํ๋ ๊ฐ์์ ์ผ์น์ํจ๋ค
Results & Discussion
Chaining continous thoughts enhances reasoning
GSM8k์์, COCONUT์ด iCoT๋ณด๋ค ์ข์ ๊ฒฐ๊ณผ๋ฅผ ์ป์๊ณ , pause as thought๋ณด๋ค๋ ์๋ฑํ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ์ป์๋ค. ์ด๋ COCONUT์ด ์ผ๋ฐํ ์ฑ๋ฅ์์ ๋ ๋ซ๋ค๋ ์๋ฏธ. (pause ํ ํฐ์ด ๋ณ๋ ฌํ์ ์ ๋ฆฌํ์ง๋ง)
c๋ฅผ 0 โ 1 โ 2๋ก ๋ณํ์ํฌ ๋ ์ฑ๋ฅ์ด ๊พธ์คํ ์ฌ๋ผ๊ฐ๋๋ฐ, ์ด๋ COCONUT๋ CoT์์์ ์ฐ์ ํจ๊ณผ๋ฅผ ๋ณด๊ณ ์๋ ๊ฒ์ผ๋ก ํด์ ํ ์ ์๋ค
- ํ ํฐ์ ์ฐ์์ ์ผ๋ก ์ฐ๊ฒฐํ๋ฉด ๊ณ์ฐ๋์ ๋ฐ๋ผ ์ฑ๋ฅ์ด ์ฌ๋ผ๊ฐ๋ ํจ๊ณผ
๋
ผ๋ฆฌ ๋
ผ์ฆ ๋ฐ์ดํฐ์
์์๋ COCONUT๊ณผ iCoT๊ฐ ๋ชจ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์๋๋ฐ, ์ด ๋ฐ์ดํฐ์
์์๋ ๊ณ์ฐ๋์ด ๋ณ๋ชฉ์ด ์๋๋ผ๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค
๋ณต์กํ ๋
ผ์ฆ์ ๊ฒฝ์ฐ, ๋ ์ฅ๊ธฐ์ ์ธ ๊ด์ ์์ ๊ฐ step์ ํ๊ฐํ ํ์๊ฐ ์๋๋ฐ, ProsQA์ ๋ณต์กํ DAG๋ ๊ณํ ๋ฅ๋ ฅ์ ์๊ตฌํ๋ค. CoT๋ ๊ฑฐ์ ์ฑ๋ฅ ํฅ์์ด ์๋ ๋ฐ๋ฉด, COCONUT๊ณผ iCoT๋ ์๋นํ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ค๋ค
The LLM still needs guidance to learn latent reasoning
๋จ๊ณ์ ์ผ๋ก ๋
ผ์ฆ ์ธ์ด ํ ํฐ์ ์ค์ฌ ๋๊ฐ๋ ๊ฒฝ์ฐ๊ฐ ์๋, ๋ชจ๋ ๋
ผ์ฆ ์ธ์ด ํ ํฐ์ ์์ ๋ ๋ฐฉ์ฑ(no-curriculum)์ผ๋ก ํ์ตํ๋ฉด, no-CoT์ ๋น์ทํ๋ค (continuous thought์ ์๋ฏธ๊ฐ ์๋ค)
**Continuous thoughts are efficient representations of
reasoning**
์ฒซ ์๋ ํ ํฐ์ ์ธ์ด ํ ํฐ์ผ๋ก decodeํ๊ธฐ ์ํด LM head๋ฅผ ํต๊ณผ์ํค๋ฉด 180, 9๋ฅผ ๋๊ฒ ๊ฐ๋ ๋ถํฌ๊ฐ ๋์จ๋ค. ์ด๋ (3ร3ร60 = 9ร60 = 540, or 3ร3ร60 = 3ร180 = 540) ์ ์ค๊ฐ ๊ณผ์ ์ ๋ณด์ฌ์ค๋ค. ๋ํ ์ฌ๋ฌ ๋
ผ์ฆ์ ๋ฐฉํฅ์ด ์ค์ฒฉ๋์ด ์์์ ํ์ธํ ์ ์๋ค
Understanding the Latent Reasoning in Coconut
์ด ์ฅ์์๋ ์๋ ๋
ผ์ฆ ๊ณผ์ ์ ๋ถ์. ์ด๋ฅผ ์ํด ์ธ์ด ๋ชจ๋์ ์๋ ๋ชจ๋๋ฅผ ๋ ์์ ๋กญ๊ฒ ์๋ค๊ฐ๋ค ํ ์ ์๋ COCONUT ๋ณ์ด๋ฅผ ์ฌ์ฉํ๋๋ฐ, ์ผ๋ฐ COCONUT๊ณผ๋ ๋ค์์ ์ฐจ์ด๋ฅผ ๊ฐ์ง.
์ผ๋ฐ COCONUT
- ํ์ต์ ๋ง์ง๋ง ๋จ๊ณ์์๋ โ์ ํด์ง ์ต๋์ continuous thought ํ ํฐโ์ ๊ฐ์ง
- ์ถ๋ก ์์ โ์ ํด์ง ์ต๋์ continuous thought ํ ํฐโ๋งํผ ์๊ฐ ํ ์ธ์ด ๋ชจ๋๋ก ์ ํ
์๋ ๋
ผ์ฆ ๋ถ์์ฉ COCONUT
- ํ์ต์ ๋ชจ๋ ๋จ๊ณ์์ 0.3์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฅธ ๋จ๊ณ์ ํ์ต ๋ฐ์ดํฐ๋ก ๋์ฒด โ ์ด์ ๋จ๊ณ๋ฅผ ๊น๋จน์ง ์๊ฒ ๋จ
- ์ถ๋ก ์์ k ํ ํฐ๋งํผ ์๊ฐ ํ ์ธ์ด ๋ชจ๋๋ก ์ ํ
์๋ ๋
ผ์ฆ ๋ถ์์ฉ COCONUT์ ์ฌ์ฉํ์ฌ ์์ ์๋ ๋ชจ๋์ ์์ ์ธ์ด ๋ชจ๋๋ฅผ ์ฌ์ด๋ฅผ ๋น๊ต(์ฑ๋ฅ ๋ฑ)ํด ๋ณผ ์ ์์.
์ด ๊ฒฐ๊ณผ๋ฅผ ํตํด ์๋ ๋
ผ์ฆ ๊ณผ์ ์ด tree search์ ์ ์ฌํจ์ ๋ฐํ๊ณ , ์๋ ๋
ผ์ฆ์ด LLM์ ํ๋จ์ ์ ๋์์ ์ฃผ๋์ง ๋ถ์ํจ
Experimental Setup
๋ณ์ด COCONUT(k in {0, 1, 2, 3, 4 ,5, 6})์ ProsQA๋ฅผ ์ฌ์ฉํ์ฌ ํ๊ฐ.
ํ๊ฐ ํญ๋ชฉ:
- ์ ํ๋: ์ต์ข
๋ต์ด ๋ง์๋์ง ํ๊ฐ
- ๋
ผ์ฆ ๊ณผ์ :
- ProsQA๋ DAG๋ก ์ด๋ฃจ์ด์ง ๋ฐ์ดํฐ์
์ด๋ฏ๋ก, ๋ชจ๋ธ์ด ์ถ๋ ฅํ๋ ์ธ์ด ๋
ผ์ฆ๋ ๊ทธ๋ํ์์์ ๊ฒฝ๋ก๊ฐ ๋จ
- ์ธ์ด ๋
ผ์ฆ์ ๋ฐฐํ์ ์ธ ๋ค์ 6๊ฐ์ง ๋ฒ์ฃผ๋ก ๋ถ๋ฅ ๊ฐ๋ฅ
- Correct Path: ์ ๋ต์ ๋ง์ถ์๊ณ , ๊ฐ์ฅ ์งง์ ๊ฒฝ๋ก์
- Longer Path: ์ ๋ต์ ๋ง์ถ์์ง๋ง, ๋ ์งง์ ๊ฒฝ๋ก๊ฐ ์กด์ฌํจ
- Hallucination: ๊ทธ๋ํ์ ์กด์ฌํ์ง ์๋ edge๋ฅผ ์ด์ฉํ๊ฑฐ๋, ๋์ด์ ธ ์๋ ๊ฒฝ๋ก์
- Wrong Target: ์ ํจํ ๊ฒฝ๋ก์ด์ง๋ง, ์ค๋ต์ ๋์
- Correct Label: ๋
ผ์ฆ ํ ํฐ ์์ด ์ ๋ต์ ๋์ (no-CoT or large k)
- Incorrect Label: ๋
ผ์ฆ ํ ํฐ ์์ด ์ค๋ต์ ๋์ (no-CoT or large k)
Interpolating between Latent and Language Reasoning
- k๋ฅผ ๋๋ ค๊ฐ์๋ก ์ ํ๋์ ์ฌ๋ฐ๋ฅธ ๋
ผ์ฆ ๊ณผ์ (Correct Path, Correct Label)์ด ์ฌ๋ผ๊ฐ. ๋ํ Hallucination๊ณผ Wrong Target์ ์ค์ด๋ฆ
- k=0 ๊ณผ CoT๋ฅผ ๋น๊ตํ๋ฉด, ๋ ๋ชจ๋ธ ๋ชจ๋ ์์ ํ ์ธ์ด ๋ชจ๋๋ก ๋์ํจ์๋ k=0 COCONUT์ ์ ํ๋๊ฐ ๋ ๋๊ณ , ๋
ผ์ฆ ๊ณผ์ ์์ Correct Path ๋น์จ์ ๋ ๋๊ณ , Hallucination ๋น์จ์ ๋ ๋ฎ์์ (Wrong Target์ ๋ ๋์๋ณด์..)
- ์ด๋ ํ์ต ๊ณผ์ ์์ ๋ค๋ฅธ ๋จ๊ณ์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์์ ์ํฅ์ผ๋ก, ํ๊ธฐ ๋จ๊ณ์ ํ์ต ๋ฐ์ดํฐ๋ ์ ์ชฝ์ ๋
ผ์ฆ ๊ณผ์ ์ด ์๋ต๋์ด ๋ชจ๋ธ์ ๊ณํ ๋ฅ๋ ฅ์ ํฅ์์ํด. ์ด์ ๋ฐํด CoT์์๋ ๋ชจ๋ธ์ด ์ธ์ ๋ ์งํ ํ ํฐ์ ์์ธกํ๋๋ก ํ์ต๋์ด, ๊ทผ์์(shortsighted)์ธ ๋ชจ์ต์ ๋ณด์
- ์ด ์์์์ CoT๋ Hallucination, k=1 COCONUT์ Wrong Target์ผ๋ก ๋น ์ง์ง๋ง, k=2 COCONUT์ CORRECT PATH๋ฅผ ์ ๋
- ์ด๋ฅผ ํตํด ์ด๊ธฐ ์๋ ์๊ฐ ํ ํฐ์์๋ ์ด๋ค edge๋ฅผ ํํ ์ง ์ด๋ ค์ํ๋ค๋ ๊ฒ์ ์ ์ ์์
- ๊ฐ ๋จ๊ณ๋ง๋ค ์ธ์ด ํ ํฐ ํ๋๋ฅผ ๊ณจ๋ผ์ผ ํ๋ ์ธ์ด ๋ชจ๋์ ๋ฌ๋ฆฌ, ์๋ ๋ชจ๋์์๋ ๊ฒฐ์ ์ ๋
ผ์ฆ์ ๋๊น์ง ๋ฏธ๋ฃฐ ์ ์์ผ๋ฏ๋ก ์๋ ๋
ผ์ฆ ๊ณผ์ ์ด ์งํ๋ ์๋ก ์ค๋ต์ ์ ์ง์ ์ผ๋ก ๊ฑธ๋ฌ๋ด์ด ์ ํ๋๋ฅผ ๋์ด๋ ๊ฒ์ ํ์ธํ ์ ์์
Interpreting the Latent Search Tree
- ์๋ ์๊ฐ ํ ํฐ์ด ๋ค์ step์ด ๋ ์ ์๋ ์ฌ๋ฌ ํ๋ณด๋ฅผ ์ค์ฒฉํ์ฌ encodeํ๋ค๋ ์ ์์, search tree๋ก ํด์ํ ์ ์๋ค.
- ๋ชจ๋ frontier node๋ฅผ ๊ฐ์ ๋น์ค์ผ๋ก ๋ค๋ฃจ๋ BFS์๋ ๋ค๋ฅด๊ฒ, ๋ชจ๋ธ์ ๋ ๊ฐ์น์๋ node๋ฅผ ์ฐ์ ํ๋ ๋ฅ๋ ฅ์ด ์๋ค.
- frontier node: ํ์ฌ ์ํ์์ ๋ฐฉ๋ฌธํ node์ ์ง์ ์ฐ๊ฒฐ๋, ๋ฐฉ๋ฌธํ์ง ์์ node
- Figure 6. ์์ ์ฒซ ๋ฒ์งธ step์ Alex์ ์์ node๋ฅผ ๊ณ ๋ฅด๋ ๊ณผ์ ์ด๊ณ , ๋ ๋ฒ์งธ step์ frontier node๋ ์์ node์ด๋ค.
- frontier node: ํ์ฌ ์ํ์์ ๋ฐฉ๋ฌธํ node์ ์ง์ ์ฐ๊ฒฐ๋, ๋ฐฉ๋ฌธํ์ง ์์ node
- ์ด ๊ณผ์
์์ k ํ ํฐ ํ ์ธ์ด ๋ชจ๋๋ก ๋ชจ๋ธ์ ์ ํํ๋ฉด
Every |Concept A| is a |Concept B|
์ ๊ท๊ฒฉํ๋ ๋ฌธ์ฅ์ด ์ฐ๋ฌ์ ์ถ๋ ฅ๋๋๋ฐ, Concept A๊ฐ ๋ฑ์ฅํ ๊ฐ๋ฅ์ฑ์ ๊ณ์ฐํ๋ฉด ๋ชจ๋ธ์ด ํด๋น node์ ์ผ๋ง๋ ํฐ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ๋์ง ๋ถ์ ๊ฐ๋ฅํ๋ค
- Figure 7.์ ์ข์ธก์ ์์๋ก ํ๋ฉด, Every๊น์ง ์ธ์ด ๋ชจ๋๋ก ์ถ๋ ฅํ ๋ค์, Every๋ฅผ Embedding Layer์ ํต๊ณผ์ํค๊ณ Transformer์ ํต๊ณผ์ํค๊ณ LM head, Softmax๊น์ง ํต๊ณผ์ํค๋ฉด ๊ฐ ์ธ์ด ํ ํฐ๋ณ ํ๋ฅ ๋ถํฌ๋ฅผ ์ป๋๋ฐ, ์ด ์์ ์์ โlempusโ์ ๊ฐ์น๋ฅผ โlempusโ๋ฅผ ์ด๋ฃจ๋ ๊ฐ ํ ํฐ(โleโ, โmpโ, โusโ)๊ฐ ๋ฑ์ฅํ ํ๋ฅ ์ ๊ณฑ์ผ๋ก ํ๊ฐํ ์ ์๋ค
- ํ๋ฅ ์ ๊ตฌํ๋ ค๋ฉด p(โle")p(โmpโ|โle)p(โusโ|โleโ,โmpโ)์ฌ์ผ ํ ํ
๋ฐ, ๋ชจ๋ forward pass๋ฅผ ์์ผ๋ณด์ง ์๋ ํ, ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅํ ๊ฒ ๊ฐ๋ค
- ๊ฐ์น๋ ํ๋ฅ ์ ์ ์๋ฅผ ๋ง์กฑํ์ง ์์ ๊ฒ ๊ฐ๋ค
- ๋ง์ฐฌ๊ฐ์ง๋ก, Figure 7.์ ์ฐ์ธก์ ์์ node์ ๊ฐ์น๋ฅผ ํ๊ฐํ๋ ๊ณผ์ ์ด๋ค. ์ข์ธก๊ณผ ๋น๊ตํด ๋ณด๋ฉด, ์ข์ธก์์๋ โsterpusโ๋ฅผ ๊ฑธ๋ฌ๋ด๊ธด ํ์ง๋ง ๋๋จธ์ง node์์๋ ํ๊ฐ๋ฆฌ๋ ๊ฒฝํฅ์ ๋ณด์ธ๋ค. ์๋ ์๊ฐ์ 1 ํ ํฐ ๋ ๊ฑฐ์น ์ฐ์ธก์์๋ rorphus์ ์ง์ค(๊ฐ์น=0.87)ํ๋ค๋ ์ ์ ํ์ธํ ์ ์๋ค.
- ์๋ ๋
ผ์ฆ ๋จ๊ณ๋ฅผ ๊ฑฐ์น ์๋ก top-1์ ๊ฐ์น๊ฐ top-2, top-3๋ฅผ ์๋ํ๋ ๊ฒ์ผ๋ก ๋ชจ๋ธ์ด ์ ์ฐจ ๋์ ๊ฐ์น๋ฅผ ๊ฐ๋ node์ ์ง์คํ๋ ์ ์ ํ์ธํ ์ ์๋ค.
Why is a Latent Space Better for Planning?
์๋ ๋
ผ์ฆ ๊ณผ์ ์ด search tree๋ก ๋์ํ๋ค๋ ์ ์ ํตํด ์๋ ๋
ผ์ฆ ๊ณผ์ ์ด ์ด์งธ์ ๋ชจ๋ธ์ ๊ณํ ๋ฅ๋ ฅ์ ํฅ์์ํค๋์ง ๊ฐ์ค์ ์ธ์ธ ์ ์๋ค.
Figure 6.์ ์์์์ ๊ฐ์น๊ฐ ๋ฎ๊ฒ ํ๊ฐ๋ โsterpusโ (Figure 7. ์ข์ธก)์ ๋๋จธ์ง ์ธ node์ ์ฐจ์ด์ ์ ๋์ด์ด๋ค.
โsterpusโ๋ leaf node๋ก target์ ์ด๋ฅด์ง ๋ชปํ๋ค๋ ์ฌ์ค์ ๋ฐ๋ก ํ์ธํ ์ ์๊ณ , ๋ค๋ฅธ node๋ ๋ ๋์ ๋ฐฉ๋ฌธํ ์์์ด ์์ง ๋จ์์์ด, ํ๊ฐํ๊ธฐ ๋ ์ด๋ ต๋ค.
๐ก ๊ฐ์ค: ๋ ๋ฎ์ node๋ ์ ํํ ํ๊ฐํ๊ธฐ ๋ ์ฝ๋ค.
(nodes with lower heights are easier to evaluate accurately)
Figure 9. ๋ ์ด ๊ฐ์ค์ ์ผ์นํ๋ ํจํด์ ๋ณด์ฌ์ค๋ค. ์ค๋ต node์ ์ ๋ต node์ ๊ฐ์น์ ์ฐจ์ด๊ฐ ํด ์๋ก ๋ชจ๋ธ์ด ๊ฐ์น๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ํ๊ฐํ๊ณ ์๋ค๊ณ ๋ณผ ์ ์๋๋ฐ, ๋์ด๊ฐ ๋ฎ์ node์ผ์๋ก ๋ ๊ทธ๋ํ์ ์ฐจ์ด๊ฐ ํฌ๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ๊ฒฝ๋ก ํ์ ์ ๋ค๋ก ๋ฏธ๋ฃฐ์๋ก search ๋ฒ์๋ฅผ ํ์ฅํ์ฌ ์ข
๊ฒฐ ์ํ๊น์ง ํ์ํ ์ ์๊ฒ ๋๊ณ , ๋ชจ๋ธ์ ์ ๋ต node์ ์ค๋ต node๋ฅผ ๋ถ๋ฅํ๋ ๋ฅ๋ ฅ์ด ํฅ์๋๋ค
๊ฒฐ๋ก
์ฐ์์ ์ธ ์๋ ๊ณต๊ฐ์์ ๋
ผ์ฆํ๋ COCONUT์ LLM์ ๋
ผ์ฆ ์ฑ๋ฅ์ ํฅ์์ํจ๋ค. ๋ํ ์๋ ์๊ฐ์ search tree์ ์ ์ฌํ ๊ตฌ์กฐ๋ฅผ ๋ณด์ธ๋ค.
- CoT๋ ๊ฐ step๋ง๋ค ํ๋ฅ ๋ถํฌ๋ฅผ ํ๋์ ์ธ์ด token์ผ๋ก collapse โ ๋ถ์ฐ์
TODO
- ํจ์จํ
- COCONUT์ pretrain์์๋ ํ์ฉ