Effective Long-Context Scaling of Foundation Models

SUNGYOON LEE·2023년 10월 3일
2

Abstract

  • 본 논문은 32k 컨텍스트 윈도우를 지원하는 long-context LLM을 제시함.

학습 데이터셋

  • continual pretraining 방식을 활용하여 LLAMA 2를 학습함.
  • 데이터셋으로는 longer training 시퀀스들을 활용하고, 데이터셋도 long text를 업샘플링하여 활용함.

효과

  • 이러한 방식으로 더 긴 글에 대해 강건함을 보여주고, LLAMA2 70B long 버전은 gpt-3.5-turbo-16k의 성능도 넘음.
  • 심지어 cost-effective한, human-annotated long instruction 데이터도 필요없는 instruction tuning 과정을 통해 이러한 성능을 얻음.

ablation experiment

  • ablation experiment를 통해 pretrain dataset에 존재하는 풍부한 긴 컨텍스트 데이터셋이 강력한 성능을 얻는데 그다지 중요햔 키포인트가 아님도 밝혀냄.
  • long context의 continual pretraining이 더 효과적이고, pretrain할 때 long context를 학습하여 스크래치부터 학습하는 것과 유사한 성능을 보여줌.

Introduction

  • 기존에 long context LLM은 LLM API를 통해 이용되고 있음.
  • 예를 들어 gpt-3.5-turbo-16k나 claude 등이 있음.
  • 본 연구에서는 LLAMA2를 continual pretraining을 하였음.
  • 400B 토큰을 추가로 학습하였고, long training sequences 형태로 구성함.

Method

Continual Pretraining

  • longer sequence length를 학습하는 것은 엄청난 계산량을 요구함. 이러한 점에 의해 continual pretraining 접근 방식을 채택하게 됨.
  • continual pretraining을 할 때 최대한 원본 아키텍처를 유지하려고 했고 몇가지만 수정을 함.
    • longer context를 받기 위해 positional encoding만 수정을 진행.
    • sparse attention은 적용을 하지 않음. sparse attention은 기본적으로 모델 차원이 매우 커질 때 이점을 가지게 되는데, 70B 모델의 dimension이 이미 8192(8k)이고, attention을 계산하고, 값을 합할 때 bottleneck현상이 생기는 것은 49152(48k)이상이라서 굳이 우리가 추구하는 32k에 필요없는 것이라 적용하지 않음.

Positional Encoding

  • 기존의 positional encoding방식은 먼 토큰간의 관계 및 정보를 취합하는데 어려움을 겪어서 long context에 강건할 수 있도록 RoPE positional encoding방식으로 교체함. 여여기서 기본 RoPE를 사용한 게 아닌 hyperparameter “base frequency b”에 의해 조절되는 rotation angle을 수정함.

Data Mix

  • RoPE를 적용시켜 long context에 대한 성능 향상을 위해 LLAMA 2의 pretraining data의 비율을 조절하거나, long text 데이터를 추가함.
  • 그러나, context의 길이보다는 데이터의 퀄리티가 더 중요한 것을 발견함.

Optimization Details

  • 7B, 13B, 34B, 70B 모두 sequence의 길이는 늘림.
  • 다만 배치당 토큰 수는 동일하게 가져감. 10만스텝동안 400B 토큰 추가 학습함.
  • learning rate는 2e-5, lr_scheduler는 2000 warm-up step으로 하고, cosine 형태의 lr_scheduler를 활용함.

Instruction Tuning

  • LLM alignment용 데이터셋을 구하기에는 비용이 많이 듬. 심지어 long context면 비용이 더 많이 듬. 그래서 open source 데이터를 찾아보면 길지 않음.
  • 그래서 쉽고 간단한 방법을 제시함.
  • 사전에 구축된 대량의 다양한 short-prompt datasetlong-context benchmark에도 잘 작동하는 것을 발견함.
  • 따라서, RLHF 데이터셋과 long-context 시나리오로 구성된 self-instruct 데이터 등을 이용함.
  • 생성 데이터는 long document를 활용하여 QA 형태로 구성함.
  • 길이에 따른 instruction data는 다음과 같이 구성함.
    • 짧은 instruction data
      • 각각을 합쳐서 16k의 토큰 길이로 만듦.
    • 긴 instruction data
      • 오른쪽에 패딩을 더함.
  • 일반적으로 instruction tuning을 할 때는 output에 대해서만 loss를 구하지만, input prompt에도 loss를 구하는 것이 일관적인 성능 향상에 도움이 되는 것을 발견함.

Results

Conclusion

  • positional encoding 방식 변경, continual pretraining 방식으로 gpt-3.5-turbo-16k 성능을 능가함.
profile
매일 매일 한 걸음씩 나아가고자 합니다.

0개의 댓글