S4

Seojin Kim·2024년 8월 3일

SSM

목록 보기
3/8

Abstract

  • SSM 에서 제대로 A를 고르기만 한다면 Long range dependency를 수학적, 실험적으로 잘 handle할 수 있다.
  • S4 : Structured State Space Sequence model = 새로운 변수 정의를 통해 보다 효율적으로 모델링이 가능함을 보일 것이다
    • low rank correction : stable diagonalization → Cauchy Kernel을 활용하여 효율적 계산이 가능하게

Introduction

  • Long range dependency : 제대로 해결하는 모델이 없음 (랜덤 추측이 제일 나을지경)
  • LSSL : 실용적으로 사용하기에는 메모리 문제, 계산 효율성 문제가 있음
  • 따라서 S4는 이러한 병목현상을 해결할 방법을 제시하고자 함
  • Structured State matrices A\mathbf{A} 를 reparameterize by low rank and normal term
  • coefficient space expansion을 위해 frequency domain에서 함수를 생성 = multipole like evaluation
  • Eventually resulting Cauchy Kernel!
    • Cauchy kernel : 두 벡터 사이의 거리를 고려, 커널 폭이 작을 수록 더 가까운 벡터에 큰 가중치

      K(x,y)=11+xy2γ2K(x,y)={1 \over {1+{∥x−y∥^2 \over γ^2}}}
  • General purpose sequential modeling
    • 어느 한 도메인에 국한된 것이 아니라 여러 분야에 널리 사용할 수 있음
    • Deep SSM : LRD + cont. time + convolutional + recurrent
    • large scale generative modeling
    • fast autoregressive generation
    • sampling resolution change
    • learning with weaker inductive bias

Background

State-Space model : continuous time latent state model

스크린샷 2024-04-21 오후 7.48.09.png

  • 1-D input signal → N-D latent state x(t) → 1-D output signal y(t)
  • Hidden Markov Model! (HMM)
  • A, B, C, D 를 gradient descent에서 제대로 잘! 배우자.
    • D=0으로 가정할 것임 (Skip connection으로 해석되므로 쉽게 implement가능함)

HiPPO : addressing long range dependencies

스크린샷 2024-04-21 오후 7.52.10.png

  • Linear ODE가 exponentially solved 될 수 있음 → 이를 해결하고자 hIppo 를 도입하려함. : history를 기억할 수 있도록

Discrete-time SSM : Recurrent representation

  • 연속 함수 대신 이산적인 output sequence를 얻기 위해서는 step size에 대해서 표현되어야함 = sampling implicit underlying continuous signal u(t), where uk=u(kΔ)u_k = u(k\Delta) 스크린샷 2024-04-21 오후 7.54.12.png
  • Bilinear method에 따라 A를 근사해 표현

Training SSMs : The Convolutional Representation

  • Linear time-invariant SSm ~ Cont. Convolution 간의 관계
  • Initial state = 0 → convolution kernel ⇒ FFT를 통해 바로 계산이 가능함
    • how? 스크린샷 2024-04-21 오후 8.00.15.png
  • Kˉ=\bar{\mathbf{K}} = SSM convolution kernel 스크린샷 2024-04-21 오후 7.56.06.png

Method : Structured State Spaces (S4)

  • parameterization of S4 and show how to efficiently compute all views of SSM : continuous representation, recurrent representation, convolutional representation

Diagonalization

  • bottleneck of discrete-time SSM = repeated matrix multiplication by Aˉ\bar{\mathbf{A}}

스크린샷 2024-04-21 오후 8.04.07.png

  • x=Vxˉx = V\bar{x} 라고 정의하면 같은 operator에 대해 계산됨, V = 기저 변환 행렬
  • 그럼 여기서 A가 좀 더 간단한 행렬이면 보다 빠른 계산이 가능해질 것!
  • 만약 A가 대각행렬이라면 커널 행렬은 Vandermonde product 가 될 것 (훨씬 작은 연산량)
  • 그러나 이렇게 naive한 행렬 변환은 불가 : N이 exponentially large → infeasible! 스크린샷 2024-04-21 오후 8.07.56.png
  • 따라서 대신 lemma 2에 따른 대각화를 진행한다.

The S4 Parameterization : Normal Plus Low-Rank

  • well conditioned matrices V 로 대각화 = 이상적인 시나리오는 전체 A가 perfectly conditinoed matrix로 정리가 되는것
  • = Spectral Theorem of Linear Algebra : normal matrices!
  • 이는 Hippo matrix로는 분해가 안됨.. 그러나, normal and low-rank matrix의 합으로는 분해할 수 있음 스크린샷 2024-04-21 오후 8.20.26.png
    • 다만, 여전히 매우 느리고 최적화가 어려움 → 해결하고자 세가지 기술을 적용할 것
    1. 커널 행렬을 직접 계산하는 대신, 스펙트럼을 truncated generating function j=0L1Kˉjζj\sum_{j=0}^{L-1} \mathbf{\bar{K}}_j \zeta^j at the roots of unity ζ\zeta ← inverse FFT로 찾을 수 있음
    2. Generatinf function ~ matrix resolvent ~ matrix inverse instead of power : Woodbury Identity applied → (A+PQ)1(A + PQ^*) ^{-1} reducing to A1A^{-1}
    3. DIagonal matrix case = Cauchy kernel : 1ωjζk1 \over {\omega_j - \zeta_k}

S4 algorithms and computational complexity

스크린샷 2024-04-21 오후 8.21.31.png

스크린샷 2024-04-21 오후 8.21.42.png

스크린샷 2024-04-21 오후 8.22.14.png

…. 그냥 결과가 ㅁ든 도메인에서다 좋음 .. 천재?

profile
M.S Student @ KAIST GSAI

0개의 댓글