[간단 논문 리뷰] Discovering the Gems in Early Layers: Accelerating Long-Context LLMs with 1000x Input Token Reduction
긴 컨텍스트로 인한 문제
- 현재 LLM은 컨텍스트 길이가 128K 늘어남으로써, 긴 컨텍스트 입력을 처리하는데 놀라운 역량을 보여주었지만, 이는 증가된 계산 리소스와 추론지연이라는 문제가 발생합니다.
- 그래서 GemFilter는 이러한 긴 컨텍스트 입력에 대해서도 LLM 추론을 가속화하고, GPU 메모리 소비를 줄이기 위한 새로운 접근 방법을 제시하였습니다.
- GemFilter는 기존 기술인 standard attention과 SnapKB/H2O에 비해 속도와 메모리 효율성을 크게 개선하였고, 특히 Sota 방법에 비해 2.4배의 속도향상 및 30% gpu 메모리 사용량 감소를 달성합니다.

LLM이 추론을 수행하는 과정
- 현재 디코더 모델인 LLM에서 텍스트를 빠르게 생성하기 위한 한 가지는 KV Cache 최적화입니다.
- 구체적으로, LLM은 이전 토큰에 대해서 다음 토큰을 예측하는 자기회귀적 성질을 가지고 있는데, 2단계를 거쳐 토큰을 생성합니다.
- 첫 번째로는 프롬프트 계산 단계로, LLM이 모든 계층에 대한 KV Cache를 계산하여 입력 토큰의 key, value를 저장합니다.
- 다음으로는 생성 단계에서 LLM은 미리 계산된 KV Cache를 사용하여 토큰을 반복적으로 생성하여 중복 계산을 방지합니다.
- 하지만 이러한 방식은 실행시간과 입력 문장이 길어질수록, KV Cache의 크기도 선형적으로 증가하여 계산량이 높아지게 됩니다.
- 이러한 문제를 해결한 것이 바로 GemFilter 입니다.
GemFilter
- 우리는 질의를 제공할 때, LLM이 종종 답을 생성하기 전에, 초기 레이어에서 필요한 정보를 찾는다는 것을 관찰하였습니다.
- 이를 통해 중요한 정보가 답변 생성 전에 인식이 된다는 것을 시사하여, 특정 필터 계층 내에서 초기 LLM 계층을 필터로 사용하여 입력토큰을 선택하고 압축하는 전략을 사용합니다.
- 기존에는 생성 단계에서 최적화를 진행하였지만, GemFilter는 생성 단계와 더불어 프롬프트 계산 단계에서 최적화를 수행하여 실행시간과 gpu 메모리 사용량 모두 줄였습니다.
- GemFilter는 LLM을 두 번 실행하는 것에서 시작합니다.
- GemFilter는 첫 번째 패스에서는 LLM의 초기 레이어만 실행하여 주요 입력 토큰을 선택합니다. 이는 프롬프트 계산 단계에 해당하여, 이 과정은 마지막 쿼리 토큰에서 가장 많은 어텐션을 받는 상위 K 토큰을 선택합니다.
- 두 번째 패스에서는 선택한 토큰을 전체 LLM에 공급하고 생성 함수를 실행합니다.
- 이를 통해서, 128K의 컨텍스트 길이를 1024 토큰으로 감소할 수 있습니다.
