Attention의 원리, 사용 이유

용가리·2024년 8월 19일
0

반갑습니다.
Attention의 원리와 사용이유에 대해 설명해보겠습니다.

Attention이 뭔가요?

어텐션 메커니즘은 트랜스포머에 핵심 알고리즘으로 현재 지피티나 라마 바드 등 다양한 LLM모델에 활용되고 있습니다. 자연어처리 뿐만 아니라 CV에도 널리 활용되고 있으므로 꼭 알아야 한다고 생각하는 기술입니다.
저도 아직 모르고 이해가 안되는 부분이 많아 자주자주 이 글을 수정할 것 같습니다.

어텐션이 왜 등장했나요?

기존의 언어모델들은 대부분 RNN 계열을 합쳐서 encoder과 decoder로 구성하였습니다.
encoder에서 출력한 context vector를 decoder단의 ht1h_{t-1}로 입력하고, 거기에 <sos.>토큰을 입력으로 넣으면 출력으로 hth_{t}와 예측한 단어하나가 나옵니다.
그럼 그 두개를 또 입력으로 넣어서 출력 단어로 <eos.> 토큰이 나올때까지 반복하는 Sequence to Sequence 모델을 사용하였죠.

근데 RNN계열은 고질적인 문제가 있었는데, gradient가 explode하거나 vanish된다는 것입니다.

이게머선소리냐
RNN의 식을 살펴봅니다.

ht=tanh(Whhht1+Wxhxt)h_t = tanh(W_{hh} * h_{t-1} + W_{xh} * x_{t})

hth_tht1h_{t-1}에 대해 미분해보면

htht1=tanh(Whhht1+Wxhxt+bh)Whh\frac{\partial h_t}{\partial h_{t-1}} = \tanh'(W_{hh} h_{t-1} + W_{xh} x_t + b_h) \cdot W_{hh}

이 미분식의 의미는 현재 state가 과거 state에 얼마나 영향을 받느냐를 의미합니다.
식을 살펴보면 마지막단에 WhW_h가 곱해짐을 볼 수 있는데, backprop 과정에서 WhW_h가 한 10겹을 쌓으면 Wh10W_h^{10}과 같이 계속 제곱됩니다.
이녀석이 쌓일수록 gradient가 아주 커지거나 작아지게 되는것이지요.

이 문제를 고치고자 LSTM, GRU 등 다양한 모델들을 개발했지만 근본적으로는 고칠 수 없었습니다.
그래서 다른 접근을 해보자
해서 등장한 Attention입니다.

Attention의 원리가 뭔가요?

attention의 의도는 다음과 같습니다.
Sally likes Sam
Sam likes Sally
두 문장의 의미는 아예 다르지만, 사용된 단어는 같다.
두 문장의 벡터를 다르게 할 수 없을까?
같은 단어라도 문맥에 따라 다른 형태를 갖게 하자 !
어떻게 해야하지?

번역에서, Attention이 나오기 전까지 RNN 기반의 seq to seq 모델을 사용했습니다.

encoder 모델에서 출력된 context vector (그냥 제일 끝단의 Output입니다)를 decoder 모델의 초기 hidden state로 설정하고, 첫 input으로 sos(start of sentences) 토큰을 줍니다.
그럼 그 두개를 바탕으로 output이 나오는데, 이 output에 softmax를 취해서 나온 단어를 다시 입력으로 넣습니다.
이 과정을 while문으로 반복해서 eos(end of sentences)라는 토큰을 output으로 낼 때까지 돌리는거죠.

근데, 이러한 RNN 기반 모델은 단점이 너무 많았습니다.

1. 앞단의 내용을 제대로 담지 못한다. (제일 끝단만 봄)

2. 순차적으로 사용하기에 병렬화가 어렵다. (느리다)

3. RNN의 vanishig/exploding gradient 문제

이 중 1번의 문제를 해결한 것이 Attention입니다.

그럼 이걸 어떻게 구현하냐
저도 솔직히 헷갈립니다.
아는 대로 적어보겠습니다.

Query
Key
Value
총 세가지를 통해 Attention Value라는 값을 구합니다.

어텐션 함수를 정의하는데, 함수는 논문, 모델마다 조금 다를 수 있지만 근본은 동일합니다.

Attention(Q, K, V) = Attention Value

Attention Value는 쿼리(기준이 되는 단어벡터)와 Key(다른 단어 벡터들)과의 유사도를 구한 뒤, 이 유사도에 V를 가중합함으로써 이루어집니다.

유사도는 보통 Dot Product를 사용합니다.

그러니까 I love you 라는 문장을 번역할 때 Decoder에서 sos 다음으로 "나는"(Q) 이라는 단어가 나왔다 그러면 이 "나는" 의 단어 벡터를 I love you 의 세 단어 벡터들(K)과 내적합니다.

그러면 "나는" 이라는 벡터와 얼마나 방향이 같은지에 대해 알 수 있겠죠. 그 값을 WVW_V에 가중합 하여 새로운 context vector을 만듭니다.

이런 과정을 거치면, Q에 대해서 연관성이 높은 K들의 정보는 높게, 연관성이 낮은 정보들은 낮게 표현할 수 있습니다.

이러한 알고리즘을 통해 1번 문제를 해결했으니,, 나머지도 해결해야겠죠.
나머지 문제의 해결법에 대한 내용은 다음 포스트에 적도록 하겠습니다.

0개의 댓글