SliceMoE:Routing Embedding Slices Instead of Tokens for Fine-Grained and Balanced Transformer Scaling

하임·2026년 1월 9일

MoE

목록 보기
14/14

https://www.arxiv.org/pdf/2510.04286


1. 서론(Introduction) 정리

1.1 기존 Token-MoE의 한계

Sparse Mixture-of-Experts(MoE) 레이어는 각 토큰을 소수의 FFN expert에게만 보내는 방식으로, 큰 파라미터 수 대비 계산량을 줄여서 효율적으로 스케일링하는 방법으로 쓰이고 있음(Switch Transformer 등).

하지만 토큰 단위 라우팅(token-level routing)에는 실무적으로 계속 등장하는 문제가 있음:

  1. 부하 불균형(load imbalance)
    • 한 expert가 “인기 토큰”을 너무 많이 받아서 과부하되고,
    • 다른 expert는 거의 쓰이지 않는 under-utilization 문제가 생김.
    • 이는 레이턴시 스파이크 및 자원 낭비로 이어짐.
  2. 전문화(specialization)의 한계
    • 한 expert가 항상 전체 토큰 임베딩(전체 d 차원)을 처리해야 하기 때문에,
    • 비교적 좁은 특징 하위 공간(sub-space)에만 특화되기 어렵다.
    • 결과적으로 “모듈식 전문가 구조”의 장점이 약해진다.

1.2 저자들의 가설: “토큰 임베딩 안에도 서로 다른 정보가 섞여 있다”

저자들은 다음과 같이 가정함:

하나의 d차원 토큰 임베딩 벡터 안에서도,

연속된 일부 차원(segment) 은 문법/형태소 정보(syntax),

다른 일부 차원은 의미 정보(semantics) 같은 서로 다른 성격의 정보를 담고 있을 수 있다.

즉, “토큰 전체”가 아니라 “벡터의 일부 구간(slice)” 단위로 라우팅할 수 있다면:

  • 라우터가 더 미세한 단위로 선택적 계산(conditional computation) 을 할 수 있고
  • expert들은 특정 하위 공간(예: 정문법, 의미, 도메인 특징)에 더 잘 전문화될 수 있다.

1.3 SliceMoE의 핵심 아이디어

그래서 논문은 SliceMoE라는 MoE 아키텍처를 제안함:

  • 각 토큰의 hidden vector (hRdh \in \mathbb{R}^d)를 S개의 연속적인 슬라이스로 나눈다.
    h=[h(1),h(2),,h(S)],h(s)Rd/Sh = [h^{(1)}, h^{(2)}, \dots, h^{(S)}],\quad h^{(s)} \in \mathbb{R}^{d/S}
  • 각 슬라이스마다 공유된(shared) slice router로부터 top-k expert를 선택한다.
  • 선택된 expert들은 자신에게 배정된 슬라이스만 처리하고, 모든 슬라이스 결과를 다시 합쳐(token-wise reassemble) 다음 Transformer 레이어로 전달.

이 접근으로 얻는 장점:

  1. 부드러운 부하 분산(smoother load distribution)
    • 토큰 하나당 S개의 독립적인 routing 결정을 하므로,
    • 하나의 expert가 특정 토큰에만 집중되는 것이 아니라,
    • “토큰 × 슬라이스” 수준에서 섞여 들어가 자연스럽게 부하가 분산됨.
  2. 파라미터 활용도 증가
    • 슬라이스들이 서로 다른 expert 조합을 거치면서,
    • 더 다양한 expert activation 패턴이 생겨,
    • “죽은 expert”가 줄어든다.
  3. 하위공간 전문화(sub-token specialization)
    • 어떤 expert는 주로 문법적 힌트가 강한 슬라이스를,
    • 다른 expert는 의미/도메인 정보가 강한 슬라이스를 처리하게 되는 경향이 관찰됨(해석 가능성).

1.4 기여 요약

논문은 크게 네 가지 기여를 주장함:

  1. SliceMoE 라우팅 단위의 변화
    • 토큰 단위가 아니라 연속 임베딩 슬라이스 단위 라우팅을 제안.
  2. 효율적인 구현
    • 슬라이스들을 expert별로 묶어서 batched GEMM/Fused kernel로 처리하는 전략.
  3. 광범위한 실험
    • WikiText-103 LM, WMT En–De MT, AG NEWS/DBPEDIA-14/EMOTION 분류 등에서
    • 토큰-MoE 및 dense baseline 대비 낮은 perplexity, 높은 accuracy, 더 나은 load balance를 보임. (예: WikiText-103에서 TokenMoE보다 12–18% PPL 개선, inference 최대 1.7× 가속).
  4. 해석 가능성 분석
    • PCA, Expert Specialization Score 등을 통해
    • expert들이 문법 vs 의미 sub-space에 특화되는 경향을 보인다고 분석.

2. 방법론(Architecture & Methods) 상세

SliceMoE의 핵심은 3장:

3. SliceMoE Architecture 및 서브섹션 3.1–3.4입니다.

2.1 기본 구조: embedding을 슬라이스로 나누기

  • 입력 토큰 hidden vector: [ hRdh \in \mathbb{R}^d ]
  • 이를 S개의 서로 겹치지 않는 연속 슬라이스로 분할: [ h(s)Rd/S,s=1,,Sh^{(s)} \in \mathbb{R}^{d/S},\quad s = 1, \dots, S ]
  • 여기서 슬라이스 인덱스 s는 “벡터 내 연속 구간”을 의미(임의 shuffle가 아님).

이후의 과정(라우팅, expert 처리)은 각 슬라이스 h^{(s)} 단위로 이루어진다.


2.2 Slice Router & Gating (3.1)

(1) 공통 slice router MLP

각 슬라이스 (h(s)h^{(s)})는 모든 토큰과 모든 슬라이스에 공통인(shared) router MLP를 거친다:

  • Router 구조:
    RouterMLP:Rd/SRHrRE\text{RouterMLP}: \mathbb{R}^{d/S} \rightarrow \mathbb{R}^{H_r} \rightarrow \mathbb{R}^E
    • 첫 번째 Linear: (Linear(d/SHr)\text{Linear}(d/S \rightarrow H_r))
    • 활성화: ReLU
    • 두 번째 Linear: (Linear(HrE)\text{Linear}(H_r \rightarrow E) )
    • (H_r = 256) (router hidden dim), (E) = expert 수.
  • 각 슬라이스에 대한 router 출력 logits: [ g(s)REg^{(s)} \in \mathbb{R}^E ]
  • softmax를 통해 expert별 확률:
    p(s)=softmax(g(s)),pe(s)=P(expert eh(s))p^{(s)} = \text{softmax}(g^{(s)}),\quad p^{(s)}_e = P(\text{expert } e \mid h^{(s)})

(2) Top-k expert 선택과 가중 합 라우팅

각 슬라이스 (s)에 대해:

  • (p(s)p^{(s)})에서 확률이 높은 top-k expert를 선택 (예: k=2).
  • j번째로 선택된 expert를 (e_j)라고 하면, 슬라이스를 expert 입력으로 보낼 때는 확률로 가중:

[

h~(s)j=p(s)ejh(s)\tilde{h}^{(s)}j = p^{(s)}{e_j} \cdot h^{(s)}

\tag{1}

]

  • 각 expert (e_j)는 이 가중된 슬라이스를 입력으로 받는 표준 FFN(두 단계 MLP) 임:
    • 예: ( FFN(x)=W2,σ(W1x+b1)+b2\text{FFN}(x) = W_2 ,\sigma(W_1 x + b_1) + b_2)

(3) 슬라이스 출력 재조합

슬라이스 (s)는 top-k expert에 분배되므로, 그 결과를 다시 모아야 함:

  • 한 슬라이스에 대해:
    • 각 expert 출력을 가중합하거나(보통 합),
    • 혹은 벡터 차원을 늘려 concat하는 방식도 가능.
  • 모든 슬라이스 (s=1,,Ss = 1,\dots,S)의 출력을 원래 순서대로 붙여 전체 토큰 표현 복원: [ h=[h(1),h(2),,h(S)]Rdh' = [h'^{(1)}, h'^{(2)}, \dots, h'^{(S)}] \in \mathbb{R}^d ]
  • 이렇게 얻은 (h')가 다음 Transformer 레이어로 넘어가는 토큰 representation이 된다.

정리하면:

“토큰 전체를 expert에게 보내는 게 아니라,
토큰 벡터 안의 연속 구간(slices)을 각각 라우팅 → expert 처리 → 다시 붙인다”
구조입니다.


2.3 Slice-Level Capacity Loss (3.2)

토큰-MoE에서도 capacity / load balance loss가 중요하듯,

SliceMoE는 슬라이스 단위의 부하 균형을 맞추기 위해 slice-level capacity loss를 도입함.

(1) expert별 슬라이스 수 세기

  • 미니배치에서:
    • 토큰 수: B
    • 슬라이스 수: S
    • 전체 슬라이스 개수: (B×S(B \times S)

각 expert (e)에 대해 할당된 슬라이스 개수를 count:

  • (countse=슬라이스 (b,s):expert e 에 라우팅됨\text{counts}_e = {\text{슬라이스 } (b,s) : \text{expert } e\ \text{에 라우팅됨}})

(2) 변동계수(CV)의 제곱으로 capacity loss 정의

  • counts 분포의 평균과 표준편차를 이용하여, CV(coefficient of variation) 의 제곱을 loss로 사용:
Lcap=α(std(counts1,,countsE)mean(counts1,,countsE))2(2)L_{\text{cap}} = \alpha \cdot \left( \frac{\text{std}(\text{counts}_1, \dots, \text{counts}_E)}{\text{mean}(\text{counts}_1, \dots, \text{counts}_E)} \right)^2 \tag{2}
  • (α\alpha): 하이퍼파라미터 (논문에서는 0.01~0.2 범위 사용).

이 loss의 의미:

  • expert별 슬라이스 수가 평균에서 많이 벗어날수록 (즉, 어떤 expert는 과부하, 어떤 expert는 거의 비활성) → 표준편차 / 평균 비율이 커져서 loss 증가.
  • 따라서 학습 과정에서 각 expert가 비슷한 양의 슬라이스를 처리하도록 유도.

(3) 통계적 멀티플렉싱(LLN 관점)

저자들은 이 fine-grained capacity loss가 “통계적 멀티플렉싱” 효과를 가진다고 설명함:

  • 토큰 수 B에 대해 토큰 단위 라우팅이면 B번의 coarse한 결정인데,
  • 슬라이스 단위 라우팅이면 B × S 번의 더 작은 결정이 존재.
  • 이때, 큰 수의 법칙(law of large numbers)에 의해:
    • 슬라이스가 많이 쪼개져 여러 expert로 골고루 흘러 들어가면서,
    • 자연스럽게 부하가 평균에 가까워지는 경향이 생긴다는 것.

2.4 Cross-Slice Dropout (3.3)

SliceMoE에서는 라우터가 특정 “슬라이스-익스퍼트 쌍”에 지나치게 집착하는 것을 막기 위해

Cross-Slice Dropout을 적용한다.

(1) 적용 위치

각 슬라이스 s에 대해:

  1. router softmax에서 top-k 라우팅 확률 (p(s)ejj=1k{p^{(s)}{e_j}}{j=1}^k) 를 얻는다.
  2. 이 확률 벡터 중 일부를 무작위로 0으로 만들고(dropout),
  3. 나머지 살아남은 확률들을 다시 정규화해서 합이 1이 되도록 한다.

예:

  • k=2, dropout rate=0.2라면
    • 선택된 2개 expert 중 한 개가 random drop될 확률이 0.2 정도가 되도록 설정.

(2) 목적

  • 라우터가 항상 같은 expert만 고르는 것을 방지하고,
  • 대체 expert 후보들을 탐색(explore) 하도록 유도.
  • 하지만 k개 중 일부만 떨어뜨리고, 나머지는 남기므로 정보 흐름이 완전히 끊기지는 않게 설계.

결과적으로:

  • 학습 초반부터 다양한 expert 조합을 시도하게 만들어
  • slice-level capacity loss와 함께 전반적인 라우팅 다양성을 증가시킨다.

2.5 Fused Kernels & 효율 구현 (3.4)

슬라이스 단위로 아주 잘게 쪼개서 라우팅하면,

“작은 행렬 연산을 엄청 많이” 실행해야 하는 비효율이 생길 수 있음.

이를 해결하기 위해 저자들은 expert별로 슬라이스를 묶어서 batched GEMM으로 처리:

  1. 미니배치에서,
    • “expert e로 라우팅된 (토큰, 슬라이스) 쌍”을 모아서 하나의 큰 batch 텐서로 스택.
      • 예: (inputseRNe×(d/S)\text{inputs}_e \in \mathbb{R}^{N_e \times (d/S)})
  2. expert e 내의 FFN 레이어마다,
    • 이 inputs_e를 한 번의 batched matrix multiply로 처리 (torch.bmm, CUTLASS, Triton, FlexGEMM 등).
  3. 이렇게 하면:
    • 커널 런치 오버헤드를 크게 줄이고,
    • 메모리 접근 패턴이 좋아져,
    • 실제로 dense FFN과 거의 비슷한 수준의 하드웨어 효율을 얻었다고 보고함 (특히 A100급 GPU에서).

즉, 슬라이스로 잘게 쪼갰지만,

실제 구현은 expert별로 다시 모아서 큰 배치 연산을 하면 된다.


3. (짧게) 실험 결과 포인트

서론/방법론 중심 요청이었지만, 큰 그림을 위해 결과도 핵심만:

  • 텍스트 분류(AG NEWS, DBPEDIA-14, EMOTION)
    • SliceMoE(S=8, k=2)는 TokenMoE보다 2–4%p 정확도 향상,
    • Dense DistilBERT와 비슷하거나 그 이상 성능을 내면서,
    • Expert Load Entropy(ELE) ≈ 0.95–0.97로 거의 완벽한 부하 균형 달성.
  • WikiText-103 LM
    • Dense: PPL ≈ 31.0
    • TokenMoE: 29.1
    • SliceMoE: 25.4 (가장 낮은 perplexity)
    • inference 속도도 dense 대비 최대 1.7× 빠름.
  • WMT En–De MT
    • Dense: BLEU 27.6
    • TokenMoE: 28.2
    • SliceMoE: 29.8, ELE ≈ 0.97로 부하 균형도 우수.
  • Slice 개수 S의 영향
    • S=2→4→8로 갈수록 성능이 좋아지다가,
    • S=16, 32에서는 살짝 하락 → 적당한 granularity(S=8) 가 좋다고 보고.

profile
NLP 공부합니당

0개의 댓글