LoRA: Low-Rank Adaptation of Large Language Models 리뷰

한량·2023년 5월 31일
2

0. Abstract

LoRA는 PEFT(Parameter Effecient Fine-Tuning)의 기법 중 하나이다. Pre-trained model의 weight는 고정한 채로, 몇 개의 dense(fc) layer만 학습시켜 downstream task의 연산량을 줄일 수 있다. GPT-3을 기준으로 parameter는 10000배, GPU 메모리는 3배를 줄일 수 있다. 또한 inference 과정에서 추가적인 latency가 없음

  • PEFT: 모델의 모든 파라미터를 튜닝하는 것이 아닌 일부 파라미터만을 튜닝함으로써 모델의 성능을 적은 자원으로도 높게 유지하는 방법론
  • Downstream task: pre-trained model을 사용해, 어떤 문제를 해결하기 위해 fine-tuning 하는것
  • Upstream task: Pre-train model을 학습시키는것
  • Latency: 어떤 요청의 시작부터 완료까지 걸리는 시간

1. Introduction

LLM은 기본적으로 pre-trained model을 특정 task에 맞게 fine-tuning을 시킴. 하지만 fine-tuning에서 모든 weight를 다시 학습시키면 GPT-2, GPT-3, RoBERTa 등 큰 모델의 경우 학습에 몇 달이 걸림.

이전 연구에서 over-parameterized model들은 low intrinsic dimension에 기반하고 있다는 사실에 기반해, 저자는 학습 과정에서도 모델은 low intrinsic rank을 갖고 있을 것이라 가정함.

LoRA는 기존 pre-trained weight는 고정하고, 몇 개의 dense layer만 rank decomposition matrices를 최적화하는 방식으로 학습시키기로 함.

위 그림처럼 기존 pre-trained weight WW는 고정하고 low rank decomposition된 weight A,BA, B만 학습시켜 WW에 더해줌. A,BA, B의 크기는 WW보다 작아 time, computational cost를 최대 3배까지 줄일 수 있음. 또한 task에 따라 LoRA module(A,BA, B)만 바꿔주면 되기 때문에 storage requirement, task-switching overhead를 줄일 수 있음. 이 외에도 추가적인 inference latency가 없다, 다른 기법들과 함께 적용이 가능하다는 장점이 있음.

1.1. Terminologies and Conventions

  • dmodeld_{model}: Transformer의 input/output dimension size
  • Wq,Wk,Wv,WoW_q, W_k, W_v, W_o: Self-attention module의 query/key/value/output projection matrices
  • W,W0W, W_0: Pre-trained weight
  • ΔW\Delta W: Adaptation 중 accumulated된 gradient update
  • rr: LoRA module의 rank
  • 이전 연구의 convention을 사용하고 optimizer는 Adam을 이용
  • Transformer MLP feedforward dimension dffn=4×dmodeld_{ffn} = 4 \times d_{model}

2. Problem Statement

LoRA는 agnostic하지만 본 논문에서는 language model에 집중함.

  • agnostic: model에 구애받지 않고 해석이 가능함
maxΦ(x,y)Zt=1ylog(PΦ(ytx,y<t))\max _{\Phi} \sum_{(x, y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log \left(P_{\Phi}\left(y_t \mid x, y_{<t}\right)\right)
  • PΦ(yx)P_{\Phi}\left(y \mid x\right): Φ\Phi로 parameterized된 pre-trained model
  • Z={(xi,yi)}i=1,...,N\mathcal{Z} = \{(x_i, y_i)\}_{i=1,...,N}: context-target쌍으로 된 학습 데이터셋, xi,yix_i, y_i는 token sequence

Fine-tuning 과정에서 model은 Φ0\Phi_0으로 init.되고 objective를 maximize하기 위해 Φ0+ΔΦ\Phi_0 + \Delta \Phi 로 업데이트됨. 각 downstream task를 위해 매번 Φ0|\Phi_0|와 같은 크기의 ΔΦ|\Delta \Phi|를 학습해 엄청난 cost가 발생.

maxΘ(x,y)Zt=1ylog(pΦ0+ΔΦ(Θ)(ytx,y<t))\max _{\Theta} \sum_{(x, y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log \left(p_{\Phi_0+\Delta \Phi(\Theta)}\left(y_t \mid x, y_{<t}\right)\right)

반면 위와 같은 LoRA 방식으로 fine-tuning할 경우 Φ0|\Phi_0| 전체가 아니라 그보다 작은 Θ|\Theta|를 찾아내는 방식으로 바뀌기 때문에 compute-/memory-effecient해짐. Θ|\Theta|는 최대 Φ0|\Phi_0|의 0.01%까지 작아질 수 있음.


3. Aren't Existing Solutions Good Enough?

기존에도 transfer learning에서 parameter-/compute-effecient를 위한 방법은 몇 가지가 있었음.

하지만 adapter layer를 추가하는 방식은 hardware parellelism이 없다면 작은 bottleneck layer만 추가해도 latency가 상당히 증가해 사용하기 어려웠음.

Prefix tuning은 optimize가 어려웠음.


4. Our Method

4.1. Low-Rank-Parameterized Update Matrices

h=W0x+ΔWx=W0x+BAxh=W_0 x+\Delta W x=W_0 x+B A x
  • W0Rd×kW_0 \in \mathbb{R}^{d \times k}
  • BRd×r,ARr×kB \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k}
  • rmin(d,k)r \ll min(d,k)

W0W_0는 고정하고 A,BA, B만 학습. 이후 W0W_0ΔW=BA\Delta W = BA는 같은 input xx에 곱해진 후 output vector끼리 coordinate-wise하게 sum.

AA는 random Gaussian init., BB는 zero-init.이라 ΔW\Delta W 또한 처음에는 zero-init. ΔWx\Delta W xα/x\alpha/x로 scaling됨. α\alpha는 learning rate처럼 tuning해서 r과 같은 값으로 설정. 실제 코드에서는 보통 r,αr, \alpha는 (8, 16)이나 (16,32)를 사용한다고 함.

	...
        # Actual trainable parameters
    	# define A, B
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

	# initialize A, B
    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True
        
    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
        	# pre-trained weight W_0 * x
            result = nn.Embedding.forward(self, x)
            if self.r > 0:
            	# BA * x
                after_A = F.embedding(
                    x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                    self.norm_type, self.scale_grad_by_freq, self.sparse
                )
                # W_0x + BAx
                result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)

4.1.1. No Additional Inference Latency

LoRA를 이용하면 inference시 latency 성능 하락이 없음. 또한 다른 task에 사용할 경우엔 BABA만 제외하고 W0W_0로 학습한 다른 BAB'A'만 추가하면 되기 때문에 memory overhead가 낮음.

4.2. Applying LoRA to Transformer

본 논문에서는 trainable weight를 최소화하기 위해 LoRA를 attention weight만 적용하고 MLP module은 고정함. 이를 통해 GPT-3 175B를 기준으로 VRAM은 1.2TB에서 350GB, checkpoint size는 350GB에서 35MB로 줄임. 또한 학습 속도 또한 25% 정도 빨라짐.


5.Empirical Experiments

대부분의 경우에서 성능이 좋음

Transformer에서 한 projection matrix에 큰 r을 적용하는 것보다 모든 matrices에 작은 r을 적용하는 것이 더 성능이 좋았음.


+a) IA3

뉴럴네트워크의 Inner Activation을 줄이기도하고 늘리기도하는 어댑터를 중간에 삽입하는 방법론. 기존에 공개된 LoRA보다 적은 파라미터를 사용하면서 높은 성능을 내는 것으로 알려져있으며, GPT-3를 in-context learning 했을때 보다도 성능이 좋다 라고 주장하고 있음. 학습시간도 매우 짧아 A100 GPU 하나로 30분만에 튜닝할 수 있었다고 함.


+aa) LoRA 사용법

  1. loralib 설치
pip install loralib
# Alternatively
# pip install git+https://github.com/microsoft/LoRA
  1. 기존 nn.Linear, nn.Embedding, nn.Conv2dlora.~로 대체
# ===== Before =====
# layer = nn.Linear(in_features, out_features)

# ===== After ======
import loralib as lora
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)
  1. 학습 전, lora parameter만 학습 가능하게 설정
import loralib as lora
model = BigModel()
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:
   ...
  1. checkpoint를 저장할 때엔 state_dict가 LoRA parameter만 저장하게 함.
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch.save(lora.lora_state_dict(model), checkpoint_path)
  1. checkpoint를 불러올 때엔 load_state_dict에서 strict=False로 설정.
# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

Reference

profile
놀고 먹으면서 개발하기

0개의 댓글