# [코드리뷰] Instruction Tuning

안규원·2024년 6월 14일

AI

목록 보기
10/22

[개발환경]


라이브러리 설치

!pip3 install -q -U transformers==4.38.2
!pip3 install -q -U datasets==2.18.0
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.9.0
!pip3 install -q -U trl==0.7.11
!pip3 install -q -U accelerate==0.27.2

모듈 import

import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

[데이터셋]


데이터셋 size

dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'output', 'url'],
        num_rows: 21155
    })
})

데이터셋 sample

dataset['train'][0]

{'instruction': '양파는 어떤 식물 부위인가요? 그리고 고구마는 뿌리인가요?',
 'output': '양파는 잎이 아닌 식물의 줄기 부분입니다. 고구마는 식물의 뿌리 부분입니다. \n\n식물의 부위의 구분에 대해 궁금해하는 분이라면 분명 이 질문에 대한 답을 찾고 있을 것입니다. 양파는 잎이 아닌 줄기 부분입니다. 고구마는 다른 질문과 답변에서 언급된 것과 같이 뿌리 부분입니다. 따라서, 양파는 식물의 줄기 부분이 되고, 고구마는 식물의 뿌리 부분입니다.\n\n 덧붙이는 답변: 고구마 줄기도 볶아먹을 수 있나요? \n\n고구마 줄기도 식용으로 볶아먹을 수 있습니다. 하지만 줄기 뿐만 아니라, 잎, 씨, 뿌리까지 모든 부위가 식용으로 활용되기도 합니다. 다만, 한국에서는 일반적으로 뿌리 부분인 고구마를 주로 먹습니다.',
 'url': 'https://kin.naver.com/qna/detail.naver?d1id=11&dirId=1116&docId=55320268'}

[Pretrained Model]

모델 로드

BASE_MODEL = "beomi/gemma-ko-2b"

model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)

추론 수행

파이프라인 정의

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)

프롬프트

prompt = "한국의 트로트라는 음악에 대해 알려줘"

생성결과 확인

outputs = pipe(
    prompt,
    do_sample=True,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    add_special_tokens=True
)

print(outputs[0]["generated_text"][len(prompt):])

서 감사합니다.

이번 글에서는 <strong>1970년대</strong>의 한국인 가수들 중에서도 특히 <strong>김영철</strong><strong>박정희</strong>를 소개하고자 합니다.

<h2>김영철</h2>

김영철은 1970년대 초반부터 활동을 시작하여, 1970년대 후반까지는 대부분의 노래가 팝송으로 불...

[Instruct Tuning]


BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

모델 로드

BASE_MODEL = "beomi/gemma-ko-2b"
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", quantization_config=bnb_config)

generating_prompt

학습용 프롬프트 조정

def generate_prompt(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        prompt = f"### Instruction: {example['instruction'][i]}\n\n### Response: {example['output'][i]}<eos>"
        output_texts.append(prompt)
    return output_texts

LoraConfig

lora_config = LoraConfig(
    r=6,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

Trainer

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    max_seq_length=512,
    args=TrainingArguments(
        output_dir="outputs",
#        num_train_epochs = 1,
        max_steps=3000,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim="paged_adamw_8bit",
        warmup_steps=0.03,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=100,
        push_to_hub=False,
        report_to='none',
    ),
    peft_config=lora_config,
    formatting_func=generate_prompt,
)

모델 저장

ADAPTER_MODEL = "lora_adapter_it"

trainer.model.save_pretrained(ADAPTER_MODEL)

합쳐서 하나의 모델로

model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)

model = model.merge_and_unload()
model.save_pretrained('gemma-ko-2b-it')

[Fine-tuning Model]


FINETUNE_MODEL = "./gemma-2b-it-ko"

finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)

pipe_finetuned = pipeline("text-generation", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)
prompt = "한국의 트로트라는 음악에 대해 알려줘"
formatted_prompt = f"### Response: {prompt}\n\n### Response:"

outputs = pipe_finetuned(
    formatted_prompt,
    do_sample=True,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    add_special_tokens=True
)
print(outputs[0]["generated_text"][len(formatted_prompt):])

0개의 댓글