paper link
presentation link
(이 글은 위 표기한 프레젠테이션 영상을 참고하여 작성했습니다.)
→ 연구 질문 : attention을 sub-quadratic 알고리즘으로 대체할 수 있을까?
→ x = hidden state
→ u = input sentence
→ y = output sentence
→ k = kernel
주요 특징
기존 transformer와의 성능 비교
quality에서 Perplexity 기준 꽤 큰 차이로 성능 떨어지는 것 확인 가능
Model | PPL(OWT) |
---|---|
Transformer | 20.6 |
S4D | 24.9 |
GSS | 24.0 |
효율성에 있어서 짧은 입력 받을 경우 transformer에 비해 효율성 떨어짐
- 약 2K길이까지는 비교적 느림
- 긴 입력에 있어서는 비교적 효율적임 (빠른 시간 내에 완료)
→ 연구 목표 : transformer와의 quality gap, efficiency gap을 없애자!
quality gap을 줄이기 위함
synthetic language
Evaluating SSMs with associative recall
Model | ACC |
---|---|
Transformer | 100.0 |
S4D | 38.8 |
GSS | 19.8 |
Hungry Hungry Hippos
Associative Recall을 위해 설계됨
설계
2개의 SSM레이어 쌓기
Shift SSM : local lookup across sequence
→ input에 대한 아주 짧은 local convolution으로 이해 가능
Diag SSM : global memory
→ gating mechanism을 통해 기억할 정보를 취사선택 가능
(기존에 알고있던 Masking과 유사한 기능인 듯..)
Multiplicative interactions between outputs
H3가 어떻게 associative recall문제를 해결하는가
예시
c를 맞추기 위해 Q, K 는 c가 value에 들어올때 만 1, 1로 활성화 됨
3번째 블록에서 c가 value에 입력
4번째 블록에서 2가 value에 입력
5번째 블록에서 b가 value로 입력
7번째 블록에서 c가 value로 입력
각각의 H3 레이어가 single key값에 대한 정보를 저장하게 됨
Quality Gap 줄이기
acc
Model | ACC |
---|---|
Transformer | 100.0 |
S4D | 38.8 |
GSS | 19.8 |
H3 | 98.4 |
perplexity
Model | PPL(OWT) |
---|---|
Transformer | 20.6 |
S4D | 24.9 |
GSS | 24.0 |
H3 | 21.0 |
H3 + 2attn (hybrid) | 19.6 |
efficiency gap을 줄이기 위함
long convolution (long SSMs) with FFT convolution
FFT convolution (Fast Fourier Transform conv)
# y = u * k
u_f = torch.fft.fft(u) # input signal
k_f = torch.fft.fft(k) # convolution kernel
y_f = u_f * k_f. # pointwise multiply
y = torch.fft.ifft(y_f)# inverse FFT
- O(NlogN)의 시간복잡도
- naive attention보다 느린 속도
속도를 높이기 위해서
1) kernel fusion
y = fused_fft_conv(u,k)
- 위의 모든 연산 한번에
- SRAM 환경에 적합
- 적은 memory I/O
2) Block FFT
tensor decomposition of FFT operation
: FFT of length N
: shorter FFT
: diagonal matrix
: permutation
→ FFT 연산을 행렬 연산으로 변형
3) State Passing : long sequence로 scaling
4K, 8K보다 긴 입력을 받아서 연산하는 과정
긴 입력을 4K, 8K 길이로 잘라 chunk를 순서대로 입력하기 (recurrent view)
→ 첫번째 chunk에 대해 output 계산하는 과정에서 SSM state 또한 update됨
→ update된 SSM state를 다음 단계에 넘겨 다음 chunk의 output을 계산하는 과정에서 활용
3가지 방식 활용한 결과
논문 중간에 이러한 내용이 있다.
To preserve the sequence history, HiPPO [24] projects the history on a basis of orthogonal polynomials, which translates to having SSMs whose A, B matrices are initialized to some special matrices.
- 여기서 말하는 A, B는 SSM 연산 중 hidden state와 input sentence에 곱해지는 A, B 파라미터를 의미
직교 다항식(orthogonal polynomials)을 기반으로 시퀀스의 history를 투영하는 과정에서 SSM을 특정 행렬로 초기화하는 방식을 본따온 듯..!