Attention

leeebs·2022년 5월 5일
0

Previous

Seq2Seq는 크게 Encoder, Decoder, Generator 3 서브모듈로 나눌 수 있다. Encoder는 문장을 받아서 context 벡터로 변환해주는 역할을 하고, context 벡터가 Decoder로 가면 이전 타임스텝의 입력을 받고 Generator와 맞물려 다음 타임스텝의 출력으로 뱉어낸다. 여기서의 문제점은 긴 문장이 들어올 경우 대처가 되지 않는다는 점이다. context 벡터에 담을 수 있는 정보가 한계가 있기 때문이다. 하지만 Attention을 통해 해당 문제를 해결하게 된다.

Attention

디코더는 인코더에 필요한 정보를 요청해서 검색을 해서 필요한 정보를 취해 오는 것
디코더의 현재 나의 상황에서 인코더에 쿼리를 날려서 검색을 하고 필요한 정보를 취해서 디코더의 현재 상태와 컨캣해 new hidden state 생성 Linear Transformation 현재 상태를 잘 반영하면서 좋은 검색 결과를 이끌어내는 쿼리를 만들어내기 위해서 사용 Attention은 Linear Transformation을 잘하는 방법을 배우는 것이다. 매칭되는 key가 similarity가 높도록 Query를 잘 변환해줘야 한다. Attention은 인코더로 넘겨 받은 정보가 부족하기 때문에 현재 상태의 필요한 정보를 linear transformation을 통해 쿼리를 잘 형성하고 잘 형성된 쿼리에 의해 key와 query의 similarity가 높을 것이고 그러면 그 similarityㄹ대로 weighted sum을 해서 value를 가져오고 context 벡터와 concat해 new hidden state를 생성한다

Equations


매 타임 스텝마다 디코더의 히든 스테이트가 인코더에 대해서 어텐션 작업을 수행해야 한다.
htdech_t^{dec}는 현재 타임스텝의 디코더 hidden state이다. 사이즈는 배치사이즈, 현재 타임스텝, 히든사이즈로 htdec=(bs,1,hs)h_t^{dec}=(b_s, 1, h_s) 정의된다. Linear transformation을 통해 key에 잘 히트되도록 쿼리를 잘 조작해야 하는데 WaW_a는 히든사이즈에서 히든사이즈로 가기 때문에 Wa=(hs,hs)|W_a|=(h_s, h_s)가 되고 이 둘을 곱하면
htdec Wa=(bs,hs)(bs,bs)|h_t^{dec}\cdot\ W_a|=(b_s,h_s)*(b_s,b_s)
(bs,hs)(b_s, h_s)가 되고 Linear transformation이 반영된 쿼리가 완성된다. 그리고 나서 이것을 인코더의 전체 타임스텝에 대해서 유사도를 검사해봐야 한다. 그때 전체 타임스텝 1부터 m까지의 인코더 h1:mencTh_{1:m}^{encT}의 사이즈는 배치사이즈, m, m의 히든사이즈에 transpose가 적용되었으므로 (bs,hs,m)(b_s, h_s, m)이 된다. (bs,hs)(b_s, h_s)(bs,1,hs)(b_s, 1, h_s)이 되므로 w=(bs,1,hs)(bs,hs,m)w=(b_s, 1, h_s)*(b_s,h_s,m)(bs,1,m)(b_s,1,m)이 된다. 이것의 의미는 미니 배치내의 각 샘플 별로 현재 타임스템에 대해서 인코더의 각 타임스텝 별로의 weight가 들어가 있다는 것이다. c=wh1:mencTc=w\cdot h_{1:m}^{encT}(bs,1,m)(bs,m,hs)(b_s,1,m)*(b_s, m, h_s)이므로 mm이 사라지고 (bs,1,hs)(b_s,1,h_s)가 되는데 이것이 context vector가 되어야 한다. 디코더의 hidden state를 re-define하는 것은 [htdec;c][h_t^{dec};c]로 concat을 수행하는데 이때의 사이즈는 (bs,1,hs2)(b_s, 1,h_s*2)가 된다. 때문에 WconcatW_concat을 통해 2xhidden_size를 hidden_size로 바꿔준다. 따라서 h~tdec\tilde{h}_t^{dec}의 사이즈는 (bs,1,hs)(b_s,1,h_s)가 된다. 이걸로 새로운 hidden state가 구해졌다. 이것을 softmax layer를 통과시키고 WgenW_gen에서 단어를 선택하고 확률분포가 나와 y^t=(bs,1,v)\hat{y}_t=(b_s,1,|v|)이 된다.

conclusion

Attention은 미분 가능한 key-value function으로 입력은 query, key, value가 된다. Attention이란 정보를 잘 얻기 위해 query를 변환하는 방법을 배우는 과정이라고 할 수 있다. Attention을 통해서 RNN의 hidden state의 한계를 극복할 수 있게 되었다. RNN의 단점을 상쇄시킬 수 있는 LSTM을 쓰더라도 context vector에 모든 정보를 담기에는 한계가 있기 때문에 Attention의 출현은 더 긴 길이의 입출력에도 대처할 수 있게 되었다.

profile
개발개발

0개의 댓글