LSTM(장단기 메모리)

AI Scientist를 목표로!·2022년 8월 3일
0

NLP

목록 보기
3/5

이전 포스팅한 RNN의 문제점은 기울기 소실 문제로 인해 시점이 길어질수록 앞의 정보가 뒤로 충분히 전달되지 못하는 문제가 발생한다는 것입니다.

이런 문제를 해결하기 위해 나온것이 LSTM 입니다.

LSTM 이란?

전통적인 RNN의 단점을 보완한 LSTM(장단기 메모리)은 은닉층의 메모리 셀에 입력 게이트, 망각 게이트, 출력 게이트를 추가하여 불필요한 기억을 지우고, 기억해야할 것들을 정합니다.

RNN에서 cell state라는 값이 추가 되었습니다.

아래 그림에서 t시점의 cell state를 CtC_t로 표현 하고 있습니다.

cell state는 위 이미지의 굵은 선입니다.
이전 시점의 cell state가 다음 시점의 cell state를 구하기 위한 입력 값으로 사용 됩니다.

hidden state의 값과 cell state의 값을 구하기 위해서 새로 추가 된 3개의 게이트를 사용합니다.

각 게이트는 망각 게이트, 입력 게이트, 출력 게이트라고 부르며, 이 3개의 게이트에는 공통적으로 시그모이드 함수가 존재합니다.

시그모이드 함수를 지나면 0과 1사이의 값이 나오게 되는데 이 값들을 가지고 게이트를 조절합니다.

LSTM 구조

(1) 전 셀의 출력에 기억 라인이 추가됨(ht1h_{t-1}Ct1C_{t-1})

단순 RNN에서는 하나였던 이전 셀의 정보 전달이, 출력(Recurrent) 외에 기억(Memory)이 추가되어 2라인으로 되어 있습니다.

Recurrent 쪽이 RNN과 같은 단기 기억이고, Memory가 장기 기억이라고 생각하시면 됩니다.

LSTM은 단기와 장기를 연관시키면서 각각 다른 라인에서 기억을 보존하고 있는 것입니다.

(2) 이전 셀의 출력(Recurrent)과 입력의 합류(ht1h_{t-1}XtX_t)

이전 셀의 출력 ht1h_{t-1}(단기 기억)과 지금 셀의 입력 XtX_t가 합쳐 집니다.
합쳐진 신호는 4개의 라인에 복사됩니다.

그 결과는 '내가 좋아하는'이라는 단기 기억에 입력값 '사과'가 더해진 것입니다
(이것은 이전 RNN과 동일).

(3) 망각 게이트(ft의 출력)

가장 윗 라인은 망각 게이트입니다.
이것은 이전 셀에서의 장기 기억 하나하나에 대해 σ(시그모이드 함수)에서 나온 0~1 사이의 값 ftf_t로 정보의 취사선택을 하는 것입니다.

단기 기억 ht1h_{t-1}와 입력 XtX_t로 '내가 사랑해 마지않는 사과'까지 인식한 시점(t)에서 장기 기억 속의 '포도'는 중요하지 않다고 판단했을 때,

σ(시그모이드)의 출력 ftf_t는 0 근처의 값이 되어, 이 기억을 망각합니다.

한편, "맛있는"이라는 정보는 중요한 것 같아서 ftf_t는 1로 그대로 남아 있습니다.

RNN이 과거의 모든 정보를 이용하려고 하면 계산량이 폭발하겠지만,
망각 게이트에 의해 원치 않는 정보를 버림으로써 폭발을 방지합니다.

(4) 입력 게이트(CtC_t'iti_t)

단기 기억 ht1h_{t-1}과 입력 XtX_t로 합산된 입력 데이터를 장기 보존용으로 변환한 후 어떤 신호를 어느 정도의 무게로 장기 기억에 저장할지 제어합니다.

이것은 두 단계로 처리됩니다.

① tanh에 의한 변환(CtC_t' 를 출력)

들어온 정보를 그대로 흘려보내는 것이 아니라 요점을 맞춘 단적인 형태로 만드는 쪽이 정보량을 줄일 수 있고, 사용하기 좋습니다.

tanh은 내부에서 추가되는 후보 값을 결정하는 데 유용한 함수입니다.
예를 들어, '사랑해 마지않는'을 '좋아하는'이라는 후보로 바꾸는 식입니다.
이렇게 간단히 변환되어 CtC_t'가 출력됩니다.

② 입력 게이트(iti_t)에 의한 선별

LSTM은 통시적 오차 역전파(BPTT)에 의해 가중치를 조정합니다.
보통의 오차 역전파는 입력 XtX_t의 weight 조절이지만, 통시적 오차 역전파는 이외에도 이전 셀에서의 단기 기억 ht1h_{t-1}의 정보에 영향을 받습니다.

따라서 ht1h_{t-1}에서 들어오는 무관한 정보에 의해 가중치가 잘못 업데이트되는 것을 방지하기 위해 입력 게이트가 필요한 오차 신호만 제대로 전달하도록 제어하고 있습니다.

ht1+Xth_{t-1} + X_t로 만들어진 '내가 사랑해 마지않는 사과'라는 정보 중에서 입력 게이트 σ(시그모이드 함수)가 남겨둘 것과 흘려보낼 것을 선별합니다.

(5) 출력 게이트(oto_t를 출력)

hth_t는 단기 기억의 출력입니다. 위와 같은 과정을 통해 장기 기억에 단기 기억이 더해져 선별된 값(장기 기억의 출력 CtC_t)에서 단기 기억에 관한 부분만 출력합니다.
여기서도 아까와 마찬가지로 2단계로 처리됩니다.

① tanh에 의한 변환

tanh의 입력은 이전 셀에서의 장기 기억 Ct1C_{t-1}에 입력 XtX_ t를 변환한 단기 기억 CtC_t'을 더한 것입니다.
각각 망각 게이트 및 입력 게이트로 취사선택되고 있습니다.
이것을 그대로 장기 기억으로 출력하는 것이 CtC_t이지만, 거기에 포함된 단기 기억 부분도 장기 기억과 함께 포함시킴으로 인해 단기 기억만 있을 때보다 이용하기 쉽게 변환할 수 있습니다.

예를 들어, 단기 기억이 '나의 동생이 좋아하는 사과는'이었다고 합시다.
이 경우 장기 기억 나의 동생이 철수라는 중요한 요소가 있다면, 단기 기억을 좀 더 명확 '철수가 좋아하는 사과'로 변환해 버립니다.

② 단기 기억의 취사선택

입력 게이트가 스스로 셀을 보호했듯이 출력 게이트도 다음 셀에 대한 나쁜 정보의 전파를 방지합니다.
다음 셀을 활성화하기 위한 가중치 hth_t를 업데이트할 때, 관련 정보를 흘려 나쁜 영향을 주지 않도록 해야 합니다.
출력 게이트 σ(시그모이드 함수)에 의해 0~1의 범위에서 OtO_t가 출력되고, 단기 기억 출력 hth_t에 필요한 신호만 제대로 전달하도록 제어하고 있습니다.

이번에는 입력 게이트에서 "나의"라는 말이 이미 잘렸기 때문에 출력 게이트에서는 특별히 자를 말이 없습니다.
이 입력에서 출력으로도 이중으로 게이트 체크하여 관련 정보가 흐르지 않도록 철저히 하고 있습니다.
지금까지의 처리 결과 이 셀에서 '내가 좋아하는 사과'라는 정보가 hth_t에 출력된 것입니다.

profile
딥러닝 지식의 백지에서 깜지까지

0개의 댓글