마스킹| 패딩 마스크(Padding Mask), 룩 어헤드 마스킹(Look-ahead masking)

미남로그·2021년 9월 8일
1
post-thumbnail

참고 자료 출처: 딥러닝을 이용한 자연어 처리 입문

목차

  1. 마스킹(Masking)이란?
  2. 패딩(Padding)이란?
  3. 패딩 마스크(Padding Mask)
  4. 패딩 마스크 구현 방법
  5. 룩 어헤드 마스킹(Look-ahead masking)

마스킹(Masking)

마스킹이란 특정 값들을 가려서 실제 연산에 방해가 되지 않도록 하는 기법


패딩(Padding)

입력되는 문장의 모두 다를 것입니다. 이 다른 문장 길이를 조율해주기 위해 모든 문장의 길이를 동일하게 해주는 전처리 과정이 필요합니다.

짧은 문장과 긴 문장이 섞인 경우, 짧은 문장을 기준으로 연산을 해버리면 긴 문장에서는 일부 손실이 발생할 것입니다.

그래서 짧은 문장의 경우에는 숫자 0을 채워서 문장의 길이를 맞춰줘야 합니다.

케라스에서 pad_sequences()를 사용해서 전처리하기도 합니다.

그런데 여기서는 0을 채워주었지만 이게 실제로 의미 있는 값은 아닙니다. 실제 어텐션에서도 연산에서 제외할 필요가 있습니다.

숫자 0의 위치를 체크해주는 것이 바로 패딩 마스킹입니다.


패딩 마스크(Padding Mask)

앞서 구현했던 스케일드 닷 프로덕트 어텐션 함수 내부를 보면 mask라는 값을 인자로 받아, 이 mask 값에다 -1e9라는 아주 작은 음숫값을 곱해 어텐션 스코어 행렬에 더해주었습니다.

def scaled_dot_product_attention(query, key, value, mask):
... 중략 ...
    logits += (mask * -1e9) # 어텐션 스코어 행렬인 logits에 mask*-1e9 값을 더해주고 있다.
... 중략 ...

이건 입력 문장에

토큰이 있을 경우 어텐션에서 제외하기 위한 연산입니다. 예를 들어 <패드>가 포함된 입력 문장의 셀프 어텐션을 보겠습니다. 어텐션을 수행하고 어텐션 스코어 행렬을 얻는 과정은 아래와 같습니다.

<패드>는 사실 아무 의미가 없는 단어입니다. 그래서 트랜스포머에선 key의 경우 <패드> 토큰이 존재할 경우 유사도를 구하지 않도록 마스킹(Masking)을 해줍니다.

여기서 마스킹은 앞의 설명과 같이 어텐션에서 제외하기 위해 값을 가린다는 의미입니다.

어텐션 스코어 행렬에서 행에 해당하는 문장은 Query이고 열에 해당하는 문장은 key입니다. 그리고 key에 <패드>가 있는 경우에는 해당 열 전체를 마스킹합니다.

마스킹 방법은 어텐션 스코어 행렬의 마스킹 위치에 매우 작은 음숫값을 넣어주는 것이고, - 무한대에 가까운 값을 의미합니다.

현재 어텐션 스코어 함수는 소프트맥스 함수를 지나지 않은 상태이고, 어텐션 스코어 함수는 소프트맥스 함수를 지나 Value 행렬과 곱해질 것입니다.

그런데 현재 마스킹 위치에는 매우 작은 음숫값이 들어가 있으므로 어텐션 스코어 행렬이 소프트맥수 함수를 지난 후에는 0에 굉장히 가까운 값이 되어 유사도를 구할 때 <패드> 토큰이 반영되지 않게 됩니다.

위의 이미지가 예시입니다. 소프트맥스 함수를 지나면 각 행의 어텐션 가중치의 합은 1이 됩니다. 단어 <패드>의 경우는 0이 되어 유의미한 값을 갖고 있지 않습니다.


패딩 마스크 구현 방법

구현하는 방법은 입력된 정수 시퀀스에서 패딩 토큰의 인덱스인지, 아닌지를 판별하는 함수를 구현해야 합니다.

  • 정수 시퀀스에서 0인 경우에는 1로 변환
  • 정수 시퀀스에서 0이 아닌 경우에는 0으로 변환

def create_padding_mask(x):
   mask = tf.cast(tf.math.equal(x, 0), tf.float32)
   # (batch_size, 1, 1, key의 문장 길이)
   return mask[:, tf.newaxis, tf.newaxis, :]

임의의 정수 시퀀스 입력을 넣어 변환 결과를 보겠습니다.

print(create_padding_mask(tf.constant([[1, 21, 777, 0, 0]])))
tf.Tensor([[[[0. 0. 0. 1. 1.]]]], shape=(1, 1, 1, 5), dtype=float32)

위의 벡터를 통해 알 수 있는 건 1의 값을 가진 위치의 열을 어텐션 스코어 행렬에서 마스킹 용도로 사용할 수 있다는 것입니다.

  • 0, 0 → 마스킹 용도

위의 벡터를 스케일드 닷 프로덕트의 어텐션 인자로 전달하면 스케일드 닷 프로덕트 어텐션에서는 위 벡터에다가 매우 작은 음숫값인 -1e9를 더하고 이를 행렬에 더하여 해당 열을 전부 마스킹하게 됩니다.


룩 어헤드 마스킹(Look-ahead masking)

순환 신경망, RNN과 트랜스포머는 문장을 입력받을 때 입력받는 방법이 다릅니다.

RNN은 각 step마다 단어가 순서대로 입력으로 들어가는 구조입니다. 반면 트랜스포머의 경우는 문장 행렬을 만들어 한 번에 행렬 형태로 입력됩니다.

이 특징 때문에 추가적인 마스킹이 필요합니다.


RNN

http://torch.ch/blog/2016/07/25/nce.html

RNN의 Decoder 과정을 보겠습니다.

RNN은 구조상 다음 단어를 만들어갈 때, 자신보다 앞에 있는 단어들만 참고해서 다음 단어를 예측합니다.

  1. 첫 번째 Step
    현재까지의 입력: What → 출력: is

  2. 두 번째 Step
    현재까지의 입력: What is → 출력: the

  3. 세 번째 Step
    현재까지의 입력: What is the → 출력: problem


트랜스포머

트랜스포머는 문장 행렬로 들어가서 위치와 상관없이 모든 단어를 참고해서 다음 단어를 예측합니다. 이 문제를 해결하기 위해 다음에 나올 단어를 참고하지 않도록 가리는 기법이 룩 어헤드 마스킹 기법입니다.

이 기법은 어텐션을 수행할 때, Query 단어 뒤에 나오는 key 단어들에 대해 마스킹합니다.

https://youtu.be/xhY7m8QVKjo

빨간색 부분이 마스킹 된 부분입니다. 빨간색이 실제 어텐션 연산에서 가리는 역할을 합니다. 이 덕분에 현재 단어를 기준으로 이전 단어들하고만 유사도를 구할 수 있습니다.

행을 Query, 열을 Key로 표현된 행렬임을 감안하고 천천히 행렬을 살펴봅시다.

예를 들어, Query 단어가 '찾고'라면 그 행에는 < s > <나는>, <행복을>, <찾고>까지의 열만 보입니다. 그 뒤 열은 아예 빨간색으로 칠해져 있어서 유사도를 구할 수 없도록 해놓았습니다.

저 빨간색 부분을 마스킹 함수로 구현해보겠습니다.


def create_look_ahead_mask(x):
   seq_len = tf.shape(x)[1]
   look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
   padding_mask = create_padding_mask(x)
   return tf.maximum(lood_ahead_mask, padding_mask)
print(create_look_ahead_mask(tf.constant([[1, 2, 3, 4, 5]])))
tf.Tensor(
[[[[0. 1. 1. 1. 1.]
   [0. 0. 1. 1. 1.]
   [0. 0. 0. 1. 1.]
   [0. 0. 0. 0. 1.]
   [0. 0. 0. 0. 0.]]]], shape=(1, 1, 5, 5), dtype=float32)

대각선의 형태로 숫자 1이 채워지는 것을 볼 수 있습니다. 이 마스킹과 패딩 마스킹은 별개이므로, 이 마스킹만 수행했을 때 만약 숫자가 0인 단어가 있다면 이것도 패딩을 진행해야 합니다.

그래서 create_look_ahead_mask() 함수는 내부적으로 앞에 구현했던 패딩 마스크도 호출합니다.

숫자 0이 포함된 경우를 살펴봅시다.

print(create_look_ahead_mask(tf.constant([[0, 5, 1, 5, 5]])))
tf.Tensor(
[[[[1. 1. 1. 1. 1.]
   [1. 0. 1. 1. 1.]
   [1. 0. 0. 1. 1.]
   [1. 0. 0. 0. 1.]
   [1. 0. 0. 0. 0.]]]], shape=(1, 1, 5, 5), dtype=float32)

0이 들어가서 행렬 1열의 값이 모두 1로 바뀌었습니다.

profile
미남이 귀엽죠

0개의 댓글