어텐션은 쿼리(query)와 비슷한 값을 가진 키(key)를 찾아서 그 값(value)을 얻는 과정입니다. 파이썬의 딕셔너리 자료형은 다음과 같이 사용할 수 있습니다.
dic = {'computer':8, 'dog':2, 'cat':3}
이와 같이 key와 value에 해당하는 값들을 넣고 key를 통해 value 값에 접근할 수 있습니다. 다시 말하면 쿼리가 주어졌을 때 key값에 따라 value값에 접근할 수 있습니다.
def key_value_func(query):
weights = []
for key in dic.keys():
weights += [is_same(key, query)]
weight_sum = sum(weights)
for i, w in enumerate(weights):
weights[i] = weights[i] / weight_sum
answer = 0
for weight, value in zip(weights, dic.values()):
answer += weight * value
return answer
def is_same(key,query):
if key==query:
return 1.
else:
return .0
위 코드는 순차적으로 dic 변수 내부의 key 값들과 query 값을 비교하여 key가 같을 경우 weights 변수에 1.0을 더하고, 다를 경우 0을 더합니다. 그리고 weights를 weights의 총합으로 나누어 그 합이 1이 되도록 만들어 줍니다. 다시 dic 내부의 value 값들과 weights의 값에 대해 곱하여 더해줍니다. 즉, weight가 1.0인 경우에만 value값을 answer에 더합니다.
만약 computer, dog, cat과 puppy의 유사도가 각각 0.1, 0.9, 0.7이라면 key_value_func함수에 puppy를 테스트하면 값은 2.823이 나올 것입니다.
2.823 = .1 / (0.9+0.7+0.1) 9 + 0.9 / (0.9+0.7+0.1) 2 + 0.7 / (0.9+0.7+0.1) * 3
원래 is_same 함수는 0과 1로만 이루어진 불연속적인 값이었지만 이 둘의 유사도를 고려하면 0에서 1사이의 연속적인 값을 weights에 할당하여 key_value_func 함수를 수행할 수 있습니다.
인코더의 시점을 각각 1,2, ..., N이라고 하고, 은닉상태를 각각 , 디코더의 현재시점 에서의 은닉상태를 라고한다면, 번째 단어를 예측하기 위해서는 이전시점()의 은닉상태와, 이전시점의 출력단어가 필요하고,어텐션에서 출력 단어 예측에는 어텐션 값(attention value)라는 새로운 값을 필요로합니다. 어텐션에선 이 어텐션 값을 구하기위해 를 전치하고 각 은닉 상태와 내적을 수행합니다.
모든 시점에서 와 를 내적한 값이 라고 한다면 어텐션 값인 는 가 됩니다.
이제 어텐션 값을 구했으니, 이 값과 를 결합(concatenate)하여 하나의 벡터로 만들어 준 뒤, 이 값을 라고 하고, 이 를 가중치 행렬과 곱한 후에 tanh함수를 지나도록 해서 입력으로 활용합니다.
이 과정을 식으로 정리하면 아래와 같습니다.
어텐션을 쓰는 이유는 은닉 상태만으로는 문장의 모든 정보를 완벽하게 전달하기 어렵기 때문입니다. 특히 문장이 길어질수록 이 문제는 더 심각해지기 때문에, 디코더의 time-step마다 현재 디코더의 은닉 상태에 따라 필요한 인코더의 정보에 접근하여 그 정보를 사용하기 위함입니다.
아래는 파이토치로 구현한 어텐션 모델이다.
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
class DecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size):
super(DecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(output_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
output = self.embedding(input).view(1, 1, -1)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = self.softmax(self.out(output[0]))
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
class AttnDecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
super(AttnDecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.max_length = max_length
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size)
self.out = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs):
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
attn_weights = F.softmax(
self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
참조 김기현의 자연어처리 딥러닝 캠프
https://wikidocs.net/22893
https://tutorials.pytorch.kr/intermediate/seq2seq_translation_tutorial.html?highlight=attention