Mamba는 SSM(State Space Models) 계열의 아이디어를 확장한 모델로, 긴 시퀀스 데이터를 다룰 때 효율적입니다.
SSM은 입력과 현재 상태(state)로부터 다음 상태를 계산하는 모델입니다.


이후 매 시간마다 h(t)를 사용해 y(t)를 계산합니다. 고전적인 SSM에서는 A, B, C가 고정된 매개변수로 학습됩니다.

Mamba의 핵심은 Selective SSM입니다.
또한 순차적 계산을 GPU에서 효율적으로 수행하기 위해, Mamba는 chunkscan(병렬 prefix-scan) 알고리즘을 사용하여 전체 연산을 병렬화하고 선형 시간으로 수행합니다.
Mamba의 수식을 풀어서 쓰면 다음과 같습니다.:
이 수식을 보면 앞에서부터 곱과 합을 누적하는 과정이라는 걸 알 수 있습니다. 이건 누적합(prefix sum)과 누적곱(prefix product) 문제로 바꿀 수 있습니다.
병렬 알고리즘에서는 이런 누적합/곱을 빠르게 계산하는 prefix-scan(스캔 연산)이 있습니다.
병렬 scan 알고리즘은 크게 두 단계로 나눕니다.

루트에는 0을 놓고 시작.

왼쪽 자식은 "부모의 prefix"를 그대로 받음.

오른쪽 자식은 "부모의 prefix + 왼쪽 자식 합"을 받음. 이 과정을 트리 아래로 내려가며 반복합니다.

트리의 노드 깊이만큼 병렬 시간 처리 시간이 걸리기 때문에 O(logN)의 시간복잡도가 걸립니다.
Prefix-scan 자체는 log N 단계라도 코스트가 커질 수 있고, GPU는 작은 블록 단위 연산에 최적화되어 있기 때문에 그대로 사용할 수 없습니다.
그래서 Mamba는 chunkscan이라는 변형을 씁니다. 핵심은 시퀀스를 chunk 단위(예: 128 토큰)로 잘라서 scan을 두 번 하는 것입니다.
chunkscan 타일 연산은 matmul과 다르게 tile간 결과에 의존성이 생깁니다.
mamba에서 chunkscan 커널 구현을 보면 다음과 같습니다.
각 위치 l의 출력 out[l]을 수식으로 표현하면 다음과 같습니다.
첫번째 항: intra-chunk scan
두번째 항: 이전 chunk에서 넘어온 상태의 기여
세번째 항: 잔차 게이트
이들을 모두 곱한 뒤 s ≤ l (즉 과거만) 모아서 더합니다.
코드로 보면 다음과 같습니다.
def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
_, _, ngroups, _, _ = cb.shape
batch, seqlen, nheads, headdim = x.shape
# _, _, ngroups, dstate = B.shape
# assert B.shape == (batch, seqlen, ngroups, dstate)
_, _, nchunks, chunk_size = dt.shape
assert seqlen == nchunks * chunk_size
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
# exp(dA[l] - dA[s])
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
# s ≤ l 덧셈을 위한 마스크
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
if D.dim() == 1:
D = rearrange(D, "h -> h 1")
out = out + x * D
return out
모든 선형 대수 연산(dot, matmul, batch-matmul, transpose, trace, ...)을 통일된 표기로 표현할 수 있는 도구입니다. 같은 인덱스끼리는 곱한 뒤 sum하고, 남은 인덱스는 그대로 출력으로 남습니다.