[NLP] Gate를 활용해 RNN의 단점을 극복한 모델: LSTM, GRU

nkw011·2022년 7월 20일
0
post-thumbnail

1. Short-term Memory

RNN은 단기 기억을 가진다는 단점이 존재한다. 길이가 긴 문장에 대해서 학습이 효과적이지 않다는 뜻이다. 이는 RNN이 가진 구조의 특징 때문이다.

RNN은 순전파시 반복적으로 동일한 파라미터를 곱한다. 따라서 길이가 길어질수록 역전파할 때 vanishing gradient 현상을 발생시킨다. 점점 작아지는 gradient로 인해 모델 파라미터가 업데이트되지 않고 학습이 제대로 이루어지지 않기 때문에 문장 내에서 어떤 정보를 보았는지 기억하지 못하게된다.

  • 모델 파라미터가 업데이트되지 않는다면 정확한 학습을 진행하지 못하게 된다. label과 예측된 값의 비교를 통해 loss를 구하게 되고 이를 모델 파라미터에 반영해 올바른 학습이 되도록 해야되는데 모델 파라미터가 업데이트 되지 않으면 처음 잘못된 예측을 계속해서 반복하게된다.

LSTM은 Long Short-Term Memory의 약자로 RNN이 가진 short-term memory 단점을 보완한 모델이다. 그래서 이름도 short-term memeory를 길게(long) 늘렸다고 해서 Long Short-Term Memory이다.

어떤 방식을 통해 해당 문제를 보완했는지 살펴보자.

2. LSTM의 Core Idea

LSTM은 long-term dependency 문제를 극복하기위해 cell state vector와 gate라는 구조를 추가하였다.

cell state vector

RNN은 입력으로 input vector, hidden state vector 총 2개의 벡터 정보를 받는다.

LSTM은 여기에 1가지 벡터를 더해 총 3개의 벡터 정보를 받는다.

  • input vector
  • hidden state vector
  • cell state vector

cell state vector와 hidden state vector 모두 재귀적인 형태로 다음 시점의 입력으로 들어간다.

gate

모델 내부로 흘러들어오는 정보를 통제한다. 학습이 진행되면서 어떤 정보를 기억해야하는지 또는 잊어버리야할지 선택하는 역할을 하게된다.

3가지의 gate가 존재하며 맡고 있는 역할은 다음과 같다.

  • input gate: 입력으로 들어오는 정보의 양을 통제해 현재 시점에서 정보를 얼마나 기억할지 정한다.
  • forget gate: 이전 시점 cell state vector의 정보를 얼마나 기억할지 정한다. (cell state의 정보를 통제한다.)
  • output gate: 현재 시점에서 생산된 cell state vector의 정보를 얼마나 출력으로 내보낼지 선택한다.

gate는 모두 input vector와 hidden state vector를 통해 만들게된다.

  • input gate: i=σ(Wi[xt,ht1]+bi)i = \sigma(W_i\cdot[x_t, h_{t-1}] + b_i)
  • forget gate: f=σ(Wf[xt,ht1]+bf)f = \sigma(W_f\cdot[x_t, h_{t-1}] + b_f)
  • output gate: o=σ(Wo[xt,ht1]+bo)o = \sigma(W_o\cdot[x_t, h_{t-1}] + b_o)

σ()\sigma()는 sigmoid 함수로 0 ~ 1 사이의 값을 출력으로 내보낸다. 선형변환을 통해 만들어진 벡터의 각 원소를 0 ~ 1 사이로 바꿔 gate의 출력값이 곱해지는 벡터의 차원마다 정보를 얼마나 기억할지 통제한다고 보면된다.

  • 0 ~ 1 사이이기 때문에 확률로 생각할 수 있다.

3. LSTM 구조

LSTM은 매 시점마다 총 3가지의 벡터가 입력된다.

  • xtx_t: 현재 시점 input vector
  • ht1h_{t-1}: 이전 시점 hidden-state vector
  • ct1c_{t-1}: 이전 시점 cell-state vector
  • 현재 시점에 사용되는 gate: iti_t, ftf_t, oto_t

현재 시점 cell state vector ctc_t 생성

먼저 forget gate를 활용하여 이전 시점 cell state vector의 정보를 얼마나 기억할지 결정한다.

  • forget gate: ft=σ(Wf[xt,ht1]+bf)f_t = \sigma(W_f\cdot[x_t, h_{t-1}] + b_f)
  • forget gate를 이전 시점 cell state vector에 원소별 곱셈(element-wise product)을 한다.
    • ftct1f_t \odot c_{t-1}

forget gate를 통과한 cell state vector에 새롭게 더할 정보를 생산하고 input gate를 활용하여 새롭게 생산된 정보를 얼마나 기억할지 결정한다.

  • ct~=tanh(Wc[xt,ht1]+bc)\tilde{c_t} = tanh(W_c \cdot [x_t, h_{t-1} ] + b_c)
  • input gate: it=σ(Wi[xt,ht1]+bi)i_t = \sigma(W_i\cdot[x_t, h_{t-1}] + b_i)
  • input gate를 새롭게 생산된 만들어진 벡터에 원소별 곱셈(element-wise product)한다.
    • itct~i_t \odot \tilde{c_t}

forget gate를 통과한 이전 시점 cell-state vector에 input gate를 통과한 새로운 정보를 더하여 현재 시점 cell state vector를 만든다.

  • ct=ftct1+itct~c_t = f_t \odot c_{t-1}+ i_t \odot \tilde{c_t}

생성된 cell state vector는 hidden state vector를 만들 때 쓰이고 다음 시점으로 넘어간다.

현재 시점 hidden state vector ht1h_{t-1} 생성

생성된 현재 시점 cell state vector를 output gate에 통과시켜 hidden state vector를 생성한다.

  • ot=σ(Wo[ht1,xt]+bo)o_t = \sigma(W_o\cdot[h_{t-1},x_t]+b_o)
  • ht=ottanh(ct)h_t = o_t \cdot tanh(c_t)

생성된 hidden state vector는 다음 시점으로 넘어가고, output이 되거나 next layer으로 이동한다.

cell state vector와 hidden state vector의 차이점

hidden state vector는 cell state vector에 한 번의 연산을 더해 만든 정보임을 알 수 있다.

이러한 과정의 의미는 다음과 같다고 볼 수 있다.

  • cell state에서는 기억해야할 모든 정보를 기억한다.
  • cell state에 저장된 정보 중에서 현재 시점에 필요한 정보를 hidden state에서 필터링한 것이다.

예를 들어 큰 따옴표로 시작되는 다음과 같은 문자열 “hell … 에서 를 예측해야한다고 해보자

  • 현재 _ 위치의 문자를 예측하는데 “(큰 따옴표) 는 필요하지 않다. → hidden state
  • 하지만 문자열이 끝나는 시점에 “(큰 따옴표) 를 닫아야하므로 “(큰 따옴표) 가 있다는 것을 기억해야한다. → cell state

Gradient Vanishing 현상을 어떻게 해결할까?

cell state에서 일어나는 역전파 과정을 살펴보면 그 이유를 파악할 수 있다.

cell state에서 역전파가 일어날 때 덧셈과 원소별 곱(element-wise product)를 지나게 된다.

  • gradient 계산과정에서 덧셈은 이전에 계산된 gradient를 변형하지 않고 그대로 흘려보내주는 역할을 한다.
  • 매 시점마다 다른 gate 값을 이용해 원소별 곱(element-wise product)을 하기 때문에 동일한 값을 계속 곱하지 않게 된다.
    • 만약 forget gate의 결과가 작다면 gate가 잊어야한다고 판단을 한 것이기 때문에 역전파할 때도 작은 값이 곱해져 기울기가 작아진다.
    • 반대로 forget gate의 값이 크다면 기억해야할 정보라고 판단을 한 것이기 때문에 역전파할 때도 큰 값이 곱해져 기울기가 작아지지 않은채로 전달된다.

4. GRU(Gated Recurrent Unit)

LSTM이 long-term dependency 문제를 보완하고 좋은 성능을 내었지만 모델 파라미터가 많기 때문에 계산이 오래 걸린다는 단점이 존재한다.

GRU는 계산 시간을 줄이기 위해 성능은 거의 유지하면서 LSTM의 구조를 단순화한 모델이라고 볼 수 있다.

LSTM과 차이점

  • GRU는 hidden state vector만 사용한다.
  • gate의 갯수를 2개로 줄였다.
    • update gate (zt)z_t): hidden state vector를 update하는 역할을 한다.
    • reset gate (rtr_t): 이전 시점 hidden state vector를 얼마나 무시할 지 결정한다.

GRU 구조

GRU는 매 시점마다 2개의 벡터가 입력된다.

  • xtx_t: input vector
  • ht1h_{t-1}: 이전 시점 hidden state vector

먼저 reset gate를 이용하여 이전 시점의 정보를 얼마나 무시할 지 결정한다.

  • rt=σ(Wr[xt,ht1]+br)r_t = \sigma(W_r\cdot[x_t, h_{t-1}] + b_r)
  • reset gate를 이전 시점 hidden state vector에 원소별 곱(element-wise product)한다.
    • rtht1r_t \odot h_{t-1}

이후 input vector와 같이 활용하여 현재 시점 hidden state vector의 후보를 만든다.

  • ht~=tanh(W[rtht1,xt]+b)\tilde{h_t} = tanh(W \cdot [r_t\odot h_{t-1}, x_t] + b)

update gate를 활용하여 hidden state vector를 만든다.

  • update gate: zt=σ(Wz[xt,ht1])z_t = \sigma(W_z\cdot[x_t,h_{t-1}])
  • ht=(1zt)ht1+ztht~h_t = (1-z_t)\odot h_{t-1} + z_t \odot \tilde{h_t}

5. 실습

profile
Deep Dive into Development (GitHub Blog: https://nkw011.github.io/)

0개의 댓글