LSTM

안녕하세요·2023년 11월 23일
0

Pytorch_NLP_Cookbook

목록 보기
2/3
post-thumbnail

1. Review - RNN

RNN 정리글

  • Task에 따라 여러 RNN 알고리즘이 파생된다
    e.g. 인코더/디코더 프레임워크
  • xtx_{t}ht1h_{t-1}을 combine
    1. Uxt+Wht1Ux_{t}+Wh_{t-1} : 각 ww 곱한 후 element wise
    2. W[ht,xt]W[h_{t},x_{t}] : concentrate한 후 ww 곱함
      → algebraic 관점에서는 동일
  • combine한 결과에 VV 매트릭스 곱해서 yy space로 변환
  • 한계 : 역전파 단계에서 tanhtanh를 여러번 미분 → 기울기가 빠르게 0으로 수렴
    • sequence가 길수록 학습이 안됨

그 대안으로 나온 방안이 LSTM

2. LSTM의 핵심 아이디어

2.1. Standard RNN의 문제점

한 채널에 두 가지 역할을 담아야 한다
1. 이전 step output
2. 다음 step에 넘어가는 정보

2.2. 대안

Cell state(중요한 정보만 흘러가는 information flow)를 추가하자!

어떻게 중요한 정보만 남길 것인가?

Gate를 사용하자!

Hidden state(각 step output)는 Cell state를 적절히 가공해서 내보내자!
(RNN은 hidden state 채널만 존재)

2.3. LSTM Flow

  1. Ct1C_{t-1}에서 불필요한 정보 삭제
  2. 새로운 input (xt,ht1x_{t},h_{t-1})를 보고 Ct1C_{t-1}의 중요한 정보 업데이트 → CtC_{t}
  3. CtC_{t} 가공하여 hth_{t} 생성
  4. 다음 step으로 CtC_{t}hth_{t} 전달

2.4. 핵심 개념 : Gate

  • 목적 : coefficient로 각각의 정보 중요도 계산하기 위함
  • 형태 : 0~1 사이의 값으로 이루어진 벡터
  • 필터링 단계에 따라 다른 게이트 사용 (파라미터 구별)
    • forget gate : Ct1C_{t-1}에서 불필요한 정보 필터링
    • input gate : Ct~\tilde{C_{t}}(임시 cell state)에서 중요한 정보만 필터링
    • output gate : CtC_{t}를 가공하여 hth_{t}로 만듦

3. 수식

3.1. Gate

gt=σ(Wgvinput)g_{t} = \sigma(W_{g} \cdot v_{input})

  • σ\sigma : 0~1의 값으로 Transformation output을 매핑
  • vinputv_{input} : input vector
  • WgW_{g} : Linear Transformation
    (Cell state의 dimension으로 input vector를 바꾸는 값)

Ct=gtCtC'_{t} = g_{t} \cdot C_{t}

  • 원하는 정보만 남긴 값 = 게이트 \cdot Cell state

3.2. 전체 Flow (2.3. 참조)

  1. 정보 필터링 : forget gate 구성
    ft=σ(Wf[ht1,xt]+bf)f_{t} = \sigma(W_f \cdot [h_{t-1}, x_{t}] + b_{f})

  2. 정보 업데이트 : input gate 구성
    it=σ(Wi[ht1,xt]+bi)i_{t} = \sigma(W_i \cdot [h_{t-1}, x_{t}] + b_{i})

    Ct~=tanh(WC[ht1,xt]+bC)\tilde{C_{t}} = tanh(W_C \cdot [h_{t-1}, x_{t}] + b_{C}) (바닐라 RNN과 동일)

  3. Cell state
    Ct=ftCt1+itCt~C_{t} = f_{t} * C_{t-1} + i_{t} * \tilde{C_{t}}

  4. 가공 : output gate 구성
    ot=σ(Wo[ht1,xt]+bo)o_{t} = \sigma(W_o \cdot [h_{t-1}, x_{t}] + b_{o})

    ht=ottanh(Ct)h_{t} = o_{t} * tanh(C_{t})

  5. 전달

4. 결론

4.1. vanishing gradient problem을 해결하였는가?

정보가 한참 나중 seq에서 사용될 때

  • LSTM
    • Cell state: non-linear activation function이 없음
      → 정보가 지워지지 않는 한 바로 input으로 들어가 h 만드는 데 사용
  • RNN : 계속 tanhtanh가 중첩

∴ 실제로 Sequential Data를 다룰 때 LSTM을 더 많이 쓴다!

4.2. 한계

flow의 각 단계가 중복된다
→ 더 간단하게 forget gate, input gate, output gate를 해결할 수 없을까?

LSTM보다 더 효율적인 방식 : GRU

참고 & 이미지 출처 : 2019 KAIST 딥러닝 홀로서기 세미나

profile
반갑습니다

0개의 댓글