QLoRA : Efficient Finetuning of Quantize LLMs
๐ https://arxiv.org/abs/2305.14314
๊ธฐ์
ํด์ปคํค์ ํ์ํ ๋ด์ฉ์ ๋์ถํ๊ธฐ ์ํด ๋
ผ๋ฌธ์ ์ ๋ฌธ์ ๊ผผ๊ผผํ๊ฒ ๋ณด์ง๋ ๋ชปํ๊ณ , ํต์ฌ ๋ถ๋ถ๋ง ๋น ๋ฅด๊ฒ ์ ๋ฆฌํ์๋ค. ์๋ต๋ ๋ด์ฉ์ด ์์ ์ ์์ผ๋ ์ฐธ๊ณ ๋ฐ๋
QLoRA : LoRA ์ nf4๋ก ์์ํ ํ pretrained language model ์ ์ฌ์ฉํ๋ ๋ฐฉ์

QLoRA ์ ๋ฐฉ๋ฒ๋ก :
- NF4
- Double quantization
- Paged Optimization
์๋๋ถํฐ QLoRA ๋ฅผ ์ํ ๊ธฐ๋ณธ ์ง์๋ถํฐ ๋ฐฉ๋ฒ๋ก ์ ๋ํด์ ์ค๋ช
ํ๋ ค ํ๋ค.
Quantization :
์ฐ์ฐ๊ณผ ์ ์ฅ์ ํ์ํ ์ ๋ฐ๋๋ฅผ ๋ฎ์ถ์ด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ๊ณ์ฐ ํจ์จ์ ๋์ด๋ ๊ธฐ์
Quantization ์ข
๋ฅ :
- Post-training Quantization (PTQ)
- ๋ชจ๋ธ ํ์ต ํ ์์ํ ์ํ
- ๋น๊ต์ ๊ตฌํ์ด ์ฝ๋ค๋ ์ฅ์ ์ด ์กด์ฌํ๋ค.
- ์ ํ๋ ์์ค ๊ฐ๋ฅ์ฑ ์กด์ฌ
- Quantization-Aware training (QAT)
- ํ์ต ์ค ์์ํ๋ฅผ ๊ณ ๋ ค
- ์ ํ๋ ์์ค์ด ์ ์ง๋ง ๊ตฌํ ๋ณต์ก๋๊ฐ ๋์
- Dynamic Quantization / Static Quantization
- Dynamic : ์ถ๋ก ์ค์ ์์ํ
- Static : ์ฌ์ ์ ์์ํ ๋ ๊ฐ ์ฌ์ฉ
PTQ ์ ์ ์์ํ ๋ฐฉ์ ์์ :
-
absmax(X[fp32]) : fp32์ X์์ ์ ๋๊ฐ์ด ๊ฐ์ฅ ํฐ ์น๊ตฌ๋ฅผ ๋ฝ๋๋ค.
Xfโp32=[1.5,โ2.3,0.7,โ1.2]
S=max(abs(X))=2.3
-
127/ [1๋ฒ๊ฒฐ๊ณผ] : 8 bit tensor์ max abs ๊ฐ์ธ 127์ ๋๋๋ค โ quantization constant (์์ํ ์์)
scaleย factor=S127โ=2.3127โโ55.22
-
๋์จ ๊ฐ์ x ์ ๋ค ๊ณฑํด์ค, ๊ทธ๋ฆฌ๊ณ ๋ฐ์ฌ๋ฆผ
Xquantizedโ=round(Xรscaleย factor)
Xquantizedโ=round([1.5ร55.22,โ2.3ร55.22,0.7ร55.22,โ1.2ร55.22])
Xquantizedโ=round([82.83,โ127,38.65,โ66.26])
Xquantizedโ=[83,โ127,39,โ66]
-
quantile quantization : ๊ฐ ์์ํ ๊ตฌ๊ฐ์ ํ ๋น๋๋ ๋ฐ์ดํฐ ์๊ฐ ๊ท ๋ฑํ ์์ํ
-
์ ๊ท๋ถํฌ ์ ์ ํ์ ์์ํ ๊ตฌ๊ฐ์ ๋ง๋ค๊ณ , ์ด๋ฅผ [-1,1] ์ฌ์ด๋ก ์ค์ผ์ผ๋ง ๋ ๊ฐ์ค์น ํ
์๋ฅผ ์์ํ ๊ตฌ๊ฐ์ ๋งตํํ๋ค. 16๊ฐ๋ก ์์ํ ๊ตฌ๊ฐ๋ ๋๋ ์ฃผ๋๋ฐ ์ด๋ 4bit ์์ํ์ฌ์ 16๊ฐ (nf4)
NF4 (NormalFloat 4-bit)
1. ๋ฌธ์
โ โ๋ถํฌโ ๋ฅผ ๊ณ ๋ คํ ์์ํ์ ํ์์ฑ
์ผ๋ฐ์ ์ผ๋ก, ์์ํ(Quantization) ์ ๋ค์ ๋ ๊ฐ์ง๋ฅผ ๋ชจ๋ ๊ณ ๋ คํด์ผ ํ๋ค.
- ์ ๋ฐ๋ ์์ค ์ต์ํ
- ์ฐ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํจ์จ
์ด ๋, Quantile Quantization ์ ์ ๋ณด ์ด๋ก ์ ์ผ๋ก โ๊ฐ ์์ํ ๊ตฌ๊ฐ (bin)โ ์ ํ ๋น๋๋ ๋ฐ์ดํฐ ์๊ฐ ๊ท ๋ฑโ ํ๋๋ก ์ค์ ํ๋ค. ์ฆ, ๋ฐ์ดํฐ๊ฐ ์ด๋ค ๋ถํฌ๋ฅผ ๊ฐ์ง๊ณ ์๋๋ผ๋, ๋ชจ๋ bin ์ ๊ฐ์ ๊ฐฏ์์ ๊ฐ์ด ์๋๋ก ๊ฒฝ๊ณ๋ฅผ ์ค์ ํ๋ ๋ฐฉ์์ด๋ค.
ex : ๋ฐ์ดํฐ๊ฐ 0 ๊ทผ์ฒ์ ๋ชฐ๋ ค ์์ผ๋ฉด, ๊ตฌ๊ฐ์ ๋ ์ด์ดํ๊ฒ ์ชผ๊ฐ์ 0 ๋ถ๊ทผ์ ์ ๋ฐ๋๋ฅผ ๋์ด๊ณ , ๋ฐ์ดํฐ๊ฐ ๊ฑฐ์ ์๋ ๊ทน๋จ ์์ญ(์์๋ผ์ด์ด)์๋ ๊ตฌ๊ฐ์ ๋๊ฒ ํ ๋นํจ.
๋ฌธ์ ๋ ์ด โ๋ถํฌ๋ฅผ ์ง์ ์ถ์ ํ๋ ๊ณผ์ (quantile estimation)โ์ด ๋น์ธ๊ณ (๊ณ์ฐ๋ ๋ง๊ณ ), ์์๋ผ์ด์ด๊ฐ ์์ผ๋ฉด ์ค์ฐจ๊ฐ ์ปค์ง๋ค๋ ๊ฒ.
2. ๋ฅ๋ฌ๋ ๋ชจ๋ธ ๊ฐ์ค์น (weights) ๋ โ์ ๊ท๋ถํฌโ ์ ๊ฐ๊น๋ค.
key idea : ์ด๋ฏธ ํ์ต๋ ์ ๊ฒฝ๋ง์ ๊ฐ์ค์น๋ ๋์ฒด๋ก 0์ ์ค์ฌ์ผ๋ก ํ๋ ์ ๊ท๋ถํฌ์ ๊ฐ๊น๋ค.
์ฆ โํ์ต๋ ๊ฐ์ค์น๊ฐ ๋ณดํต ยฑ ๋ช ํ์คํธ์ฐจ(ฯ) ์์ ๋ชฐ๋ ค์๋คโ ๋ ํต๊ณ์ ์ฑ์ง์ ์ด์ฉํ์๋ ๊ฒ.
โ โ์์ ํ์ค ์ ๊ท๋ถํฌ์ ๋ํด์ ์ด๋ก ์ ์ผ๋ก ์ต์ ์ธ ์์ํ ๊ตฌ๊ฐ์ ๋ฏธ๋ฆฌ ๊ณ ์ ํด๋ฒ๋ฆฌ์โ ๋ผ๋ ์ ๋ต์ด๋ค.
โ ํ๋ง๋๋ก, ๋งค๋ฒ ๊ฐ์ค์น์ ๋ถ์์๋ฅผ ์๋ก ์ถ์ ํ ํ์ ์์ด, โํ์ค ์ ๊ท๋ถํฌ์ฉ์ผ๋ก ๋ฏธ๋ฆฌ ๊ณ์ฐํด๋ ๊ตฌ๊ฐ ๊ฒฝ๊ณโ๋ฅผ ํ์ฉํ๋ฉด ๋๋ค๋ ๊ฒ.
( ์ฌ๊ธฐ์ NFk ์์ํ ์์ด๋์ด๋ฅผ ์ป์ ์ ์๋ค. )
3. NormalFloat(NFk) ์์ด๋์ด :
0 ํ๊ท , ฯ(ํ์คํธ์ฐจ) ๊ฐ ์ผ์ ํ ์ ๊ท๋ถํฌ์ ๋ง์ถฐ์ ๋ฏธ๋ฆฌ ์์ํ ๊ตฌ๊ฐ์ ๋ง๋ ๋ค, ๊ฐ์ค์น ํ
์๊ฐ ์ค์ ๋ก๋ ๋ค๋ฅธ ํ์ค ํธ์ฐจ๋ฅผ ๊ฐ์ง ์ ์์ผ๋, ์ด๋ฅผ [-1, 1] ๋ฒ์๋ก ์ ๊ทํ (์ค์ผ์ผ๋ง) ํด ๋งคํํ๋ค.
โ ์ ๊ท๋ถํฌ ์ ์ ํ์ ์์ํ ๊ตฌ๊ฐ์ ๋ง๋ค๊ณ , ์ด๋ฅผ [-1,1] ์ฌ์ด๋ก ์ค์ผ์ผ๋ง ๋ ๊ฐ์ค์น ํ
์๋ฅผ ์์ํ ๊ตฌ๊ฐ์ ๋งตํํ๋ค. 16๊ฐ๋ก ์์ํ ๊ตฌ๊ฐ๋ ๋๋ ์ฃผ๋๋ฐ ์ด๋ 4bit ์์ํ์ฌ์ 16๊ฐ
๊ฐ๋จํ ์ค๋ช
:
[-1,1] ์ฌ์ด์ 4bit ์ 16๊ฐ๊ฒฉ์ผ๋ก ์ชผ๊ฐ์ ๊ทธ ์์ ์๋ ๊ฐ์ ๊ฐ๊น์ด ๊ฑฐ๋ก ๋ฐ๊พธ๊ธฐ
(๋จ์ ๊ท ๋ฑ๊น์ง๋ ์๋๊ณ โ์ ๊ท๋ถํฌ์ ๋ถ์์โ ๋ฅผ ๊ณ ๋ คํด ์ต์ ํ๋ ์ง์ ์ ์ฐพ๊ธฐ)
- [-1, 1] ๊ตฌ๊ฐ์ 16๊ฐ๋ก ๋๋๋ค (k=4๋นํธ โ 16๊ฐ bin)
- ์ ํํ๋ ๋จ์ ๊ท ๋ฑ ๊ฐ๊ฒฉ์ด ์๋๋ผ, โ์ ๊ท๋ถํฌ์ ๋ถ์์โ๋ฅผ ๊ณ ๋ คํด ์ต์ ํ๋ ์ง์ ์ ์ฐพ์ (๋
ผ๋ฌธ์์๋ ์ด๋ก ์ ์ต์ ์ด๋ผ ํจ).
- FP32 ๊ฐ(์ค์ ๊ฐ์ค์น)์ ํด๋น ๊ตฌ๊ฐ์์ ๊ฐ์ฅ ๊ฐ๊น์ด ๋ํ๊ฐ(๊ตฌ๊ฐ์ ์ค์๊ฐ ๋ฑ)์ผ๋ก ๋งคํ
- ์: 0.71 โ 0.7333 ๊ตฌ๊ฐ์ ํด๋นํ๋ฉด, โ0.7333โ์ผ๋ก
- ์ด๋ ๊ทธ ๋ํ๊ฐ(ํน์ ๊ตฌ๊ฐ)์ ํด๋นํ๋ ์ธ๋ฑ์ค(์: 14๋ฒ bin)๋ฅผ ์ ์ฅ
- ๋ฉ๋ชจ๋ฆฌ์ 14๋ผ๋ โ4๋นํธโ ์ซ์๋ง ์ ์ฅํ๋ฉด, ์ญ์ผ๋ก ๋ณต์ํ ๋ 14โ0.7333 ์ ๋๋ก ๋ค์ ํ์ฅ ๊ฐ๋ฅ.
NF4 ์ ๋ฆฌ
- โ๋ฅ๋ฌ๋ ๊ฐ์ค์น๋ (๊ฑฐ์) N(0,ฯ^2) ๋ถํฌ๋ฅผ ๋ฐ๋ฅธ๋คโ๋ ์ ์ ์ ๊ทน ํ์ฉ!!
- โ์ ๊ท๋ถํฌ์ ์ต์ ์ธ ์์ํ ๊ตฌ๊ฐโ์ ๊ณ ์ (precomputed)ํด๋๊ณ ,
- ์ ๊ท๋ถํฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์์ํ ๋ ๊ตฌ๊ฐ์ ๊ณ ์ ํด๋๊ธฐ ๋๋ฌธ์ ์ถ๊ฐ ๊ณ์ฐ์ด ๋ถํ์
- quantile quantization ์ผ๋ก ํ๋ฉด ๋งค๋ฒ ๊ณ์ฐ์ ํด์ค์ผ ๋ผ์ ๊ณ์ฐ cost ๊ฐ ์๋นํจ
- โ๊ฐ์ค์น๋ฅผ [-1,1]๋ก ๋งคํโ ๋ง ํด์ฃผ๋ฉด ์์ฝ๊ฒ 4๋นํธ ์์ํ๋ฅผ ์ ์ฉํ ์ ์๊ฒ ํด์ฃผ๋ ๋ฐฉ๋ฒ์
๋๋ค.
- [-1,1] ์ฌ์ด๋ก ์ค์ผ์ผ๋ง ๋ ๊ฐ์ค์น๋ฅผ ์ ๊ท๋ถํฌ์ ์ต์ ์ธ ์์ํ ๊ตฌ๊ฐ์ ๋งตํ ํด์ฃผ๊ธฐ
- ํ์ต๋ ๊ฐ์ค์น ํ
์๋ฅผ [-1,1] ๋ฒ์๋ก ์ค์ผ์ผ๋งํ๋ฉด, ๋ฏธ๋ฆฌ ์ ์๋ ์ ๊ท๋ถํฌ ๊ธฐ๋ฐ ์์ํ ๊ตฌ๊ฐ (16 bin) ์ ๋ฐ๋ก ๋งตํ ๊ฐ๋ฅ
- ์ฑ๋ฅ์ ์ ์งํ ์ ์๋ ์ด์ ๋ ์์ํ ๊ตฌ๊ฐ์ด ์ฌ์ ํ์ต๋ ๊ฐ์ค์น์ ์ ๊ท๋ถํฌ์ ํน์ฑ์ ๋ง๊ฒ ์ค๊ณ๋์์ผ๋ฏ๋ก, ๊ฐ์ค์น์ ์ค์ฌ๋ถ์ ๋ ์ด์ดํ ๊ตฌ๊ฐ์ด ๋ฐฐ์น๋์ด ์ ๋ฐ๋๋ฅผ ์ ์งํ ์ ์๋ค.
QLoRA ์์๋ nf4 ์์ํ๋ฅผ ํตํด์ PLM์ quantization ์ํจ๋ค.
Double Quantization
Quantization Constant : ์์ํ ์์
์์ํ๋ ๋ฐ์ดํฐ๋ฅผ ์์ถํ์ง๋ง, ๋ฐ์ดํฐ ๋ฒ์ (ex : min, max) ๋ฅผ ๊ธฐ๋กํด์ผ ์๋ณธ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ํ ์ ์๋ค. ์ด ๋, ๊ฐ ๋ธ๋ก์ absmax ๋ฑ์ ์ ์ฅํ ๊ฒ์ด ๋ฐ๋ก ์์ํ ์์์ด๋ค.
4๋นํธ ์์ํ๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํฌ๊ฒ ์ค์ฌ์ฃผ์ง๋ง, ์์ํ ์์ ์์ฒด๋ ์ฌ์ ํ fp32๋ก ์ ์ฅ๋๋ค.
- ๊ฐ ๋ธ๋ก์ ํฌ๊ธฐ๋ฅผ 64๋ก ์ค์ ํ๋ฉด, ์์ํ ์์๋ 64๊ฐ์ ํ๋ผ๋ฏธํฐ ๋น 32๋นํธ๋ฅผ ์ฐจ์งํ๋ค.
์์ํ์์๋ฉ๋ชจ๋ฆฌ์ฌ์ฉ๋=32/64=0.5๋นํธ/ํ๋ผ๋ฏธํฐ
์ด๋ 4bit ํจ์จ์ฑ์ ์ ํ ์ํด. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ , QLoRA ๋ double quantization์ ๋์
ํ์ฌ ์์ํ ์์๋ฅผ 8bit ๋ก ์ถ๊ฐ ์์ํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ ์ค์ธ๋ค.
Double Quantization ์ ๋์ ์๋ฆฌ
- 1์ฐจ ์์ํ
- ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ 4๋นํธ๋ก ์์ํํ๋ฉฐ, ๊ฐ ๋ธ๋ก์ ๋ํด ์์ํ ์์ cFP32(2)โ ๋ฅผ ์์ฑํ๋ค.
- 2์ฐจ ์์ํ
- 1์ฐจ ์์ํ ์์ c_{FP32}^{(2)} ๋ฅผ ์ถ๊ฐ๋ก ์์ํ ํ์ฌ fp8๋ก ๋ณํํ๋ค. cFP8(2)โ=Quantize(cFP32(2)โ)
- ์ด ๊ณผ์ ์์, FP32 ์์ํ ์์์ ํ๊ท ๊ฐ์ ๋นผ์ ๋์นญ์ ์์ํ๋ฅผ ์ํํ๋ค. (์์ํ ๋ฒ์๋ฅผ ๋ ํจ์จ์ ์ผ๋ก ํ์ฉํ๊ธฐ ์ํด ํ๊ท ๊ฐ์ ๋บ๋ค๊ณ ํ๋ค.)
- ์์ํ ์์์ ์์ํ ์์ cFP32(1)โ
- 2์ฐจ ์์ํ์ ๊ฒฐ๊ณผ์ธ cFP8(2)โ ๋ฅผ ๋ณต์ํ๊ธฐ ์ํด ๋ ๋ค๋ฅธ ์์ํ ์์ cFP32(1)โ ๊ฐ ํ์ํ๋ค.
Page Optimization
Paged Optimization์ GPU VRAM์ ์ฉ๋์ ์ด๊ณผํ๋ ๋๊ท๋ชจ ๋ชจ๋ธ์ ์คํํ ๋, VRAM๊ณผ CPU RAM์ ๋์ ์ผ๋ก ํ์ฉํ์ฌ OOM(Out-of-Memory) ์๋ฌ๋ฅผ ๋ฐฉ์งํ๋ ๊ธฐ์ ์ด๋ค. ๋ชจ๋ธ์ ์ผ๋ถ ํ๋ผ๋ฏธํฐ๋ ์ค๊ฐ ์ฐ์ฐ ๊ฒฐ๊ณผ๋ฅผ VRAM์์ RAM์ผ๋ก ์ฎ๊ฒจ ์ ์ฅํ๊ณ , ํ์ํ ๋ ๋ค์ VRAM์ผ๋ก ๊ฐ์ ธ์ค๋ ๋ฐฉ์์ผ๋ก ์๋ํ๋ค. ์ด๋ฅผ ํตํด VRAM ์ฌ์ฉ๋์ ์ต์ ํํ๊ณ , ์ ํ๋ ํ๋์จ์ด ํ๊ฒฝ์์๋ ๋๊ท๋ชจ ๋ชจ๋ธ์ ํจ์จ์ ์ผ๋ก ์คํํ ์ ์๋ค.
์ฆ, QLoRA ์ ํต์ฌ์ PLM Optimization ์ผ๋ก ์ธํด ๋ฐ์ํ๋ ์ค์ฐจ๋ฅผ LoRA ๊ฐ ํ์ตํ์ฌ ๋ณด์ ํ๋ ๊ฒ์ด๋ค.
QLoRA์ ํต์ฌ ์๋ฆฌ:
- ์์ํ๋ ๋ชจ๋ธ ๋๊ฒฐ:
- QLoRA๋ 4๋นํธ๋ก ์์ํ๋ ๋ฒ ์ด์ค ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ frozen ํ๋ค.
- ์ฆ, ๋ฒ ์ด์ค ๋ชจ๋ธ์ ๊ฐ์ค์น๋ ํ์ต๋์ง ์์ผ๋ฉฐ, optimization ์ํ ๊ทธ๋๋ก ์ ์ง
- LoRA ์ด๋ํฐ๋ฅผ ํตํด ๋ณด์ :
- optimization ๋ PLM์ด ํํํ์ง ๋ชปํ๋ ์ ๋ณด(=์์ํ ์ค์ฐจ)๋ฅผ LoRA ์ด๋ํฐ๊ฐ ํ์ตํ๋ค. ( LoRA ์ด๋ํฐ๋ low-rank matrix๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ, ๋ชจ๋ธ์ ํน์ ์ ํ ๊ณ์ธต(Linear Layer)์ ์ถ๊ฐ๋๋ค. )
- LoRA ํ์ต ๊ณผ์ :
- ํ์ต ์ค, ๋ชจ๋ธ ์ถ๋ ฅ๊ณผ ์ค์ ์ ๋ต ์ฌ์ด์ ์ค์ฐจ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก LoRA ์ด๋ํฐ์ ํ๋ผ๋ฏธํฐ๋ง ์
๋ฐ์ดํธ ๋๋ค.
- LoRA ์ด๋ํฐ๋ ์์ํ๋ ๋ชจ๋ธ์ ์ถ๋ ฅ์ ๋ณด์ ํ์ฌ, ์์ํ๋ก ์ธํ ์ฑ๋ฅ ์์ค์ ์ต์ํํ๋ค.