PaLM: Scaling Language Modeling with Pathways (2)

tobigs-nlp-article·2022년 6월 1일
0

예고했듯, (2)편에서는 초대규모 언어모델의 효율적인 학습을 가능하게 한 Pathways와 그를 이해하기 위해 필요한 선행 지식들에 대해 공부할 것이다. 또한, 모델을 실험하기 위해 사용된 새로운 벤치마크 및 방법들과, few-shot의 성능을 크게 올린 chain-of-thought prompting 등에 대해 공부할 것이다.

4. Training Infrastructure

PaLM은 JAX와 T5X를 사용하며, 두 대의 TPU v4 Pods를 이용해 학습을 진행한다.

참고로, JAX는 구글이 만든 머신러닝 라이브러리로, 넘파이를 GPU에서 연산할 수 있게 하여 기존 넘파이를 개선, 훨씬 빠르게 연산을 진행한다. 또한, JIT (just in time) 컴파일 기법과 XLA 컴파일러 (Accelerated Linear Algebra)를 사용하여 런타임에 사용자가 생성한 TensorFlow 그래프를 분석하고 실제 런타임 차원과 유형에 맞게 최적화 및 여러 연산을 함께 합성하고 이에 대한 효율적인 기계어 코드를 내보낸다고 한다.

JIT 컴파일은 프로그램을 실제 실행하는 시점에 기계어로 번역하는 컴파일 기법이다.

XLA (Accelerated Linear Algebra)는 Tensorflow의 서브 프로젝트로 그래프 연산의 최적화 / 바이너리 사이즈의 최소화 등을 목적으로 하는 컴파일러. 해당 블로그에 잘 정리되어있습니다!

기존의 대규모 언어모델들은 일반적으로 두 가지 방법을 사용해 학습을 진행한다.
1. single TPU
LaMDA /GLaM이 대표적

  1. pipeline parallelism
    • Megatron-Turing NLG 530B:
      • 2240 A100 GPUs & model, data, pipeline parallelism
    • Gopher (Rae et al., 2021a)
      • four DCN-connected TPU v3 Pods
      • pipelining between pods.

명확한 이해를 위해 Data/model/pipiline parallelism에 대해 짚고 넘어가자.

  • Data Parallelism
    말 그대로 여러 대의 gpu에 데이터를 분산하여 시간을 단축하는 방법. weigth parameter를 업데이트할 때마다 여러 gpu가 학습한 결과를 종합한 후 다시 나누는 synchronization (all-reduce)이 필요하다.

  • Model Parallelsim
    모델 사이즈가 너무 커서 하나의 gpu 메모리에 다 들어가지 않는 경우, 여러 gpu에 모델 파라미터를 나누어 연산하는 것. 일반적으로 tensor parallelism과 pipeline parallelsim으로 나누어진다.

  • Tensor Parallelsim
    커다란 행렬을 여러 gpu로 나누어 연산한 후 그 결과값을 concat하는 방식

  • Pipeline parallelsim


    • 순차적으로 이어진 레이어 혹은 레이어 안의 stage를 gpu들이 나누어서 하는 방식. 순차적으로 해야할 일을 나눠하는 개념이기에 throughput의 증가는 없음.

    • 한 time에 하나의 gpu만 활성화되어 있다는 문제점을 완화하기 위해 입력 미니배치를 마이크로배치로 나누어 뿌림.
    • [단점]
      1. 여러 gpu가 활성화되지 않는 bubble 존재
      2. 마이크로배치들/weight를 계속해서 올리는 과정에서 higher memory bandwidth를 요구하게 됨

PATHWAYS

PaLM에서는 pipeline을 사용하지 않으면서 540B나 되는 큰 모델을 6144개의 칩으로 확장시키기 위해 다음의 전략을 사용한다.

첫째, 각 TPU v4 Pod은 full copy of the model parameters를 가진다. 이때, 각 weight tensor는 12-way model parallelism을 통해 3072개의 chip에 나누어져 있고, 마찬가지로 데이터도 256-way fully sharded data parallelism를 통해 분산되어있다.

둘째, 각 pod 단위에서 two-way data parallelism을 적용하기 위해, client-server architecture를 이용한다.

이때, 하나의 Python client는 학습 배치의 반을 나누어 각 pod에게 쏘아준다. 그럼 각 pod은 forward/backward를 standard within-pod data and model parallelism을 이용해 수행한다.

각 pod은 반 개의 배치를 이용해 얻은 gradients의 결과를 공유하고 다음 timestep을 위한 파라미터를 준비한다.

위의 사진을 보면서 이해를 해보자.

Python client는 상단의 왼쪽과 같이 dataflow program 을 구성한다.

component A는 각 pod의 forward+backward computation 결과이다. component B는 optimizer update 정보를 담고 있으며, cross-pod gradient transfer가 존재한다.

The Pathways program은 처음에 각 pod으로 하여금 component A 를 연산하게 하고, 각 결과를 pod끼리 공유하게 한다. 마지막으로 component B를 연산한다.

Pathways system은 여러 대의 accelerator chips를 활용하여 모델을 훈련시킬 수 있는 특징들을 가진다.
1. asynchronous gang-scheduling를 각 pod scheduler에 적용, 즉, 여러 대의 프로세스가 소통하면서 동시에 작동하도록 함으로써, latency를 줄임
2. sharded-dataflow execution model을 사용하기 때문에 data transfer 시간을 줄임 (refer to Barham et al. (2022) for details).

5. Training Efficiency

연구자들은 성능을 평가하기 위해 MFU(model FLOPs utilization)라는 새로운 지표를 소개한다. 이는 모델이 어떻게 구현되었는지에 구애받지 않고 더 정확한 성능을 측정할 수 있다.

이는 관측된 throughput, 즉, 일초당 처리할 수 있는 토큰의 수와 이론적 최대 throughput과의 비율이다. 다시 말해, 기대되는 최대 throughput과 현실 사이의 비교라고 할 수 있다.

6. Training Setup

  • Weight initialization

    • input embeddings
      • E ∼ N (0, 1)
    • pre-softmax output logits을 1/√n로 scale
    • kernel weights (embeddings & layer norm scales 제외 나머지)
      • “fan-in variance scaling”, i.e., W ∼ N(0,1/√n_in),
        • n_in : input dimension
  • Optimizer

    • Adafactor optimizer
      • Adam (Kingma & Ba, 2014) with “parameter scaling,
  • Sequence length: 2048

  • Dropout: X

7. Evaluation

7.1 English NLP tasks

7.1.1 Finetuning

PaLM model을 SuperGLUE benchmark로 파인 튜닝해 성능을 비교해보았다.

6을 확인하면 파인튜닝했을때, sota에 비건하는 성능을 보인다는 것을 확인할 수 있으며, 7에서는 아직까지 퓨샷보다는 파인튜닝이 성능이 훨씬 좋긴하다는 것을 확인할 수 있다. 마지막으로 8에서 그래도 디코더 모델들중에서는 가장 성능이 좋다는 것을 강조한다.

7.2 BIG-bench

150 tasks : logical reasoning, translation, question answering, mathematics, 등등

  • goal step wikihow
    • 이벤트 사이의 순서 맞추기. Example:
      • Input: In order to ”clean silver,” which step should be done first? (a) dry the silver (b) handwash the silver
      • Answer: (b) handwash the silver
  • logical args
    • 논리적으로 맞는 inference 고르기. Example:
      • Input: Students told the substitute teacher they were learning trigonometry. The substitute told them that instead of teaching them useless facts about triangles, he would instead teach them how to work with probabilities. What is he implying? (a) He believes that mathematics does not need to be useful to be interesting. (b) He thinks understanding probabilities is more useful than trigonometry. (c) He believes that probability theory is a useless subject.
      • Answer: (b) He thinks understanding probabilities is more useful than trigonometry.
  • english proverbs
    • 어떤 속담이 주어진 문서를 가장 잘 설명하는지 고르기. Example:
      • Input: Vanessa spent lots of years helping out on weekends at the local center for homeless aid. Recently, when she lost her job, the center was ready to offer her a new job right away. Which of the following proverbs best apply to this situation? (a) Curses, like chickens, come home to roost. (b) Where there is smoke there is fire (c) As you sow, so you shall reap.
      • Answer: (c) As you sow, so you shall reap.
  • logical sequence
    • 어떤 것이 논리적으로 올바른 순서를 가지는지 고르기. Example:
      • Input: Which of the following lists is correctly ordered chronologically? (a) drink water, feel thirsty, seal water bottle, open water bottle (b) feel thirsty, open water bottle, drink water, seal water bottle (c) seal water bottle, open water bottle, drink water, feel thirsty
      • Answer: (b) feel thirsty, open water bottle, drink water, seal water bottle
  • navigate
    • navigational instructions을 따랐을 때 결과 맞추기. Example:
      • Input: If you follow these instructions, do you return to the starting point? Always face forward. Take 6 steps left. Take 7 steps forward. Take 8 steps left. Take 7 steps left. Take 6 steps forward. Take 1 step forward. Take 4 steps forward.
      • Answer: No
  • mathematical induction
    • 실제 수학과 다르더라도, 규칙을 주었을 때 맞추는지 확인. Example:
      • Input: It is known that adding 2 to any odd integer creates another odd integer. 2 is an odd integer. Therefore, 6 is an odd integer. Is this a correct induction argument (even though some of the assumptions may be incorrect)?
      • Answer: Yes

7.3 Reasoning

  • chain-of-thought prompting
    최근의 연구들에 따르면, reasoning의 중간 단계를 함께 만들게 함으로써 답변의 질이 비약적으로 향상되었다고 한다. 이를 chain-of-thought라고 한다. 사진을 통해 이해해보자.

7.4 Code Tasks

코드 태스크를 평가하는 지표로는 pass@k metric를 사용한다. pass@k metric는, 모델로 하여금 K개의 샘플을 만들게 하고, 이 중 하나라도 문제를 해결하면 맞다고 하는 지표이다.

7.5 Translation

  • English-centric language pairs
  • Extremely-low resource language pairs

7.6 Multilingual Natural Language Generation

  • MLSum (Scialom et al., 2020) – Summarize [de/es]
  • WikiLingua (Ladhak et al., 2020) – Summarize [en/es/ru/tr/vi → en]
  • XSum (Narayan et al., 2018) – Summarize : single sentence. [en]
  • Clean E2E NLG (Novikova et al., 2017; Duˇsek et al., 2019) – one or two sentences. [en]
  • Czech Restaurant response generation (Duˇsek & Jurˇc ́ıˇcek, 2019) – Given: dialog context -> generate : response [cz]
  • WebNLG 2020 (Gardent et al., 2017; Castro Ferreira et al., 2020) – subject-predicate-object이 주어지면, 완벽한 문장으로 만들기 [en/ru]

해당 실험의 의의는 다음과 같다:

  • Effectiveness of finetuning

    • 특히 요약 태스크에서는 finetuning이 아직 중요했다.
  • Generation quality of English vs. non-English

    • PaLM은 영어 생성에 있어서는 5/6의 요약 태스크에서 sota를 달성하였다(심지어 input이 영어일때도). 하지만 non-English summariation (MLSum) 에서는 좋은 성능을 보이지 못했으며 few-shot and finetuning과의 간극도 굉장히 컸다. 즉, 외국어 생성에 있어서는 더 나쁜 성능!
  • 1-shot vs. finetuning gap

    • Data-to-Text results를 보면 few-shot과 best finetuned results의 간극이 굉장히 줄어들었다는 것을 볼 수 있다.
  • Few-shot summarization

    • large improvement from 8B → 62B, & smaller-but-significant improvement from 62B → 540B.

7.7 Multilingual Question Answering

TyDiQA-GoldP benchmark

비영어 데이터가 별로 없었음에도 불구하고 꽤나 좋은 결과를 낸다는 것을 확인할 수 있다. (≈ 22% of the 780B training tokens). 비교하자면, mT5는 PaLM보다 6배 많은 비영어 데이터로 훈련을 시켰으며 ByT5는 1.5배 더 많은 비영어 데이터를 훈련시켰다.

profile
투빅스 NLP 논문세미나입니다!

0개의 댓글