
์ ๋ธ๋ก๊ทธ https://seojune.site/post/huggingface-transformer ์์ ์ฝ๊ธฐ์ ์ต์ ํ๋์ด ์์ฑ๋์์ต๋๋ค.
ํ๊น
ํ์ด์ค๐ค transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ๋ค์ํ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์
์ ๊ฐํธํ๊ฒ ์ฌ์ฉํ ์ ์์ด ๋๋ฆฌ ์ฌ์ฉ๋๋ ํจํค์ง์ด๋ค. transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ฉ์ธ์ ๋ฐ๋ก ๋ชจ๋ธ ํ๋ จ์ ์ํ Trainer ํจ์๋ผ ํ ์ ์๋๋ฐ, ๋ชจ๋ธ ํ๋ จ์ ์ํ ์ ๋ง ๋ง์ ๊ธฐ๋ฅ๋ค์ ์ ๊ณตํ๊ณ ์๋ค. ๊ทธ๋ฐ๋ฐ ์ด๋ ๊ฒ ๋ง์ ๊ธฐ๋ฅ๋ค์ ์ ๋ถ ์ค์ ํ๋ ค๋ฉด ์์ฒญ๋๊ฒ ๋ง์ ํ๋ผ๋ฏธํฐ๋ค์ ์ธ์๋ก ๋ฐ์์ผ ํ๋ค. (ํ์๋ฅผ ํฌํจํด) ๋ง์ ์ฌ๋๋ค์ด ํ๊น
ํ์ด์ค ์ฌ์ฉ์ ์ด๋ ค์ํ๋ ์ด์ ์ด๋ค.
์ด ๊ธ์์๋ ๋จผ์ ํ๊น
ํ์ด์ค Trainer ํด๋์ค์ ์ฌ์ฉ๋ฒ์ ์์๋ณธ๋ค. ํนํ ์์ฃผ ์ฌ์ฉํ๋ ์ธ์๋ค๋ก ์ด๋ค ๊ฒ์ด ์๋์ง, ์ด๋ค ๋ฐฉ์์ผ๋ก Trainer์ ๋ฃ์ด์ค์ผ ํ๋์ง๋ฅผ ์์๋ณธ๋ค. ๋, GPT-2 ๋ชจ๋ธ์ IMDB ๋ฐ์ดํฐ์
์ผ๋ก fine-tuning์์ผ์ ์ํ ๋ฆฌ๋ทฐ ๋ถ๋ฅ์ ์ค์ ๋ก ์ ์ฉํด๋ณธ๋ค.
ํ๊น ํ์ด์ค(HuggingFace) ๐ค๋ ๋จธ์ ๋ฌ๋/๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ๋ จ, ๊ณต์ ํ๊ณ ๋ฐฐํฌํ๊ธฐ ์ํ ์๋น์ค๋ค์ ์ ๊ณตํ๋ ํ์ฌ์ด์ ์ปค๋ฎค๋ํฐ์ด๋ค. ํ๊น ํ์ด์ค์์ ์ ๊ณตํ๋ ์๋น์ค๋ ํฌ๊ฒ ๋ ๊ฐ๋ก ๋๋ ์ ์๋ค.
๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์
๋ฑ์ ๊ณต์ ํ ์ ์๋ HuggingFace Hub
์ฌ์ ํ์ต๋ ๋ชจ๋ธ๋ค๊ณผ ๋ฐ์ดํฐ์
๋ค์ด ์ ์ ๋ฆฌ๋์ด ์๊ณ , ์คํ์์ค๋ก ์ฝ๊ฒ ๋ค์ด๋ฐ๊ณ ์ฌ์ฉํ ์ ์๋๋ก ๋์ด์๋ค.
๋ชจ๋ธ์ ํ๋ จ์ํค๊ธฐ ์ํ ํ์ด์ฌ์ transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ
ํ๋ธ์ ๊ณต๊ฐ๋ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์
์ ์ฝ๊ฒ ๊ฐ์ ธ์ฌ ์ ์๊ณ , ๋ชจ๋ธ ํ๋ จ ๋ํ ์ฝ๊ฒ ํ ์ ์๋๋ก API๋ฅผ ์ ๊ณตํ๋ค.
์ด ๊ธ์์ ๋ค๋ฃจ๋ ๊ฒ์ ์ด ์ค ์ ์๋ก, ํ๊น
ํ์ด์ค์์ ์ ๊ณตํ๋ transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฌ์ฉ๋ฒ์ ์์๋ณด๊ฒ ๋ค.
Trainer์ TrainingArguments์ ๋งค๊ฐ๋ณ์๋คํ๊น
ํ์ด์ค Trainer API๋ฅผ ์ด์ฉํด์ ๋ชจ๋ธ์ ํ๋ จํ ๋, ๋งค๊ฐ๋ณ์๋ค ์ค ์ผ๋ถ๋ TraininingArguments๋ก, ์ผ๋ถ๋ Trainer์ ๋ฃ์ด์ฃผ์ด์ผ ํ๋ค. ์์๋ฅผ ๋ค์๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
from transformers import Trainer, TrainingArguments
training_arguments = TrainingArguments(
output_dir='./results',
evaluation_strategy="epoch",
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
learning_rate=3e-5,
logging_strategy="epoch",
load_best_model_at_end=True,
save_strategy="epoch",
metric_for_best_model="accuracy",
)
trainer = Trainer(
model=AutoModelForSequenceClassification.from_pretrained(โbert-base-uncasedโ),
train_dataset=ds_train,
eval_dataset=ds_test,
args=training_arguments,
compute_metrics=compute_metrics,
)
trainer.train()
์ด๋ ๊ฒ TrainingArguments ๊ฐ์ฒด๋ฅผ ํ๋ ๋ง๋ ํ, ์ด๋ฅผ ๋ค์ Trainer์ args ๋งค๊ฐ๋ณ์๋ก ๋ฃ์ด์ฃผ๋ ์์ผ๋ก training์ ํ์ํ ์ ๋ณด๋ค์ ์๋ ค์ฃผ์ด์ผ ํ๋ค. ๋จผ์ TrainingArguments๊ฐ ๋ฐ๋ ๋งค๊ฐ๋ณ์๋ค๋ถํฐ ์ดํด๋ณด์.
TrainingArguments๊ณต์ API ๋ฌธ์์์ TrainingArguments๋ฅผ ์ฐพ์๋ณด๋ฉด argument๋ค์ ๋งค์ฐ ๊ธด ๋ชฉ๋ก์ด ๋์จ๋ค. ์ด ์ค ์์ฃผ ์ฐ์ด๋ ๋ช ๊ฐ๋ฅผ ์ ๋ฆฌํด๋ณด์๋ค.
output_dir: str: ์ ์ผํ required argument๋ก, ํ๋ จ๋ ๋ชจ๋ธ๊ณผ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ด๋์ ์ ์ฅํ ์ง๋ฅผ ์๋ฏธํ๋ค. ํ์ผ์์คํ
์์ ๊ฒฝ๋ก๋ก ์ง์ ํ๊ฑฐ๋(e.g. โ./resultsโ) ํ๊น
ํ์ด์ค ํ๋ธ์ ์ ์ฅํ repository ์ด๋ฆ์ผ๋ก ์ค์ ํ ์๋ ์๋ค (e.g. โbert-uncased-imdb-finetunedโ). ํ์์ ๊ฒฝ์ฐ ํธ๋ ์ด๋ ํ trainer.push_to_hub()๋ฅผ ํ๋ฉด ์๋์ผ๋ก ํ๋ธ์ ์
๋ก๋๋๋ค. overwrite_output_dir: bool: True๋ก ์ค์ ์, output_dir์ ์ด๋ฏธ ํ์ผ์ด ์กด์ฌํ๋ ๊ฒฝ์ฐ์๋ ๋ฎ์ด์ฐ๊ธฐ๋ฅผ ํ๋ค.num_train_epochs: ํ๋ จํ ์ํฌํฌ(epoch)์ ์์ด๋ค.per_device_train_batch_size: train ์์ batch size๋ฅผ ์ง์ ํด์ค๋ค.per_device_eval_batch_size: evaluation ์์ batch size๋ฅผ ์ง์ ํด์ค๋ค.learning_rate: learning rate(ํ์ต๋ฅ )์ ์ง์ ํด์ค๋ค.lr_scheduler_type: learning rate scheduler๋ฅผ ์ฌ์ฉํ๊ณ ์ถ์ ๊ฒฝ์ฐ ์ง์ ํด์ค ์ ์๋ค. default๋ linear์ด๊ณ , constant, cosine, cosine_with_restarts, polynomial, constant_with_warmup ๋ฑ์ ์ฌ์ฉํ ์ ์๋ค. * constant๊ฐ ์๋ LR scheduler๋ฅผ ์ฌ์ฉํ ์ ์ถ๊ฐ์ ์ผ๋ก ํ๋ผ๋ฏธํฐ๋ฅผ ๋ฃ์ด์ฃผ์ด์ผ ํ๋ค. weight_decay: L2 weight decay๋ฅผ ์ค์ ํ๋ค.logging_strategy: ๋ก๊ทธ๋ฅผ ์ด๋ป๊ฒ ๋จ๊ธธ์ง ์ค์ ํ๋ค. ๊ธฐ๋ณธ๊ฐ์ steps๋ก, no๋ก ์ค์ ํ๋ฉด ๋ก๊ทธ๋ฅผ ๋จ๊ธฐ์ง ์์ผ๋ฉฐ epoch๋ก ์ค์ ํ๋ฉด ํ ์ํฌํฌ๊ฐ ๋๋ ๋๋ง๋ค, steps๋ก ์ค์ ํ๋ฉด ๋งค logging_steps๋ง๋ค ๋ก๊ทธ๋ฅผ ๋จ๊ธฐ๊ฒ ๋๋ค.logging_strategy=โstepsโ๋ก ์ค์ ์ logging_steps๋ฅผ ๊ฐ์ด ์ค์ ํด์ฃผ์ด์ผ ํ๋ค. logging_steps๋ ์ ์๋ฅผ ๋ฃ์ด์ค ์๋ ์์ง๋ง 0์์ 1 ์ฌ์ด์ float ๊ฐ์ผ๋ก ์ง์ ํ ์๋ ์๋๋ฐ, ์ด ๊ฒฝ์ฐ ์ ์ฒด training step์ ์ด๋ฅผ ๊ณฑํ ๊ฐ์ ์ฌ์ฉํ๋ค.save_strategy: ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ธ์ ์ ์ฅํ ์ง ์ค์ ํ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก no, steps, epoch ์ค ํ๋๋ฅผ ๊ณ ๋ฅผ ์ ์์ผ๋ฉฐ steps๋ก ์ค์ ๋๋ฉด save_steps๋ฅผ ๊ฐ์ด ์ค์ ํด์ฃผ์ด์ผ ํ๋ค.save_total_limit: ์ต๋๋ก ์ ์ฅํ ์ ์๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ์ ์๋ฅผ ์ง์ ํด์ค๋ค.evaluation_strategy: evaluation์ ์ธ์ ์ํํ ์ง๋ฅผ ๊ฒฐ์ ํ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก no, steps, epoch ์ค ํ๋๋ฅผ ๊ณ ๋ฅผ ์ ์์ผ๋ฉฐ steps๋ก ์ค์ ๋๋ฉด evaluation_steps๋ฅผ ๊ฐ์ด ์ค์ ํด์ฃผ์ด์ผ ํ๋ค.use_cpu: True๋ก ์ค์ ํ ์ ์ฌ์ฉ๊ฐ๋ฅํ GPU๊ฐ ์์ด๋ CPU์์ ์คํํ๋ค.seed: ๋๋ค ์๋๋ฅผ ์ค์ ํด์ค๋ค. ๋ง์ ๊ฒฝ์ฐ 42๋ฅผ ์ฌ์ฉํ๋ค. fp16: fp16 mixed-precision์ ์ฌ์ฉํ ์ง ์ฌ๋ถ๋ฅผ True/False๋ก ์ค์ ํด์ค๋ค. disable_tqdm: False๋ก ์ค์ ์, ํ ๋์ tqdm์ ์ฌ์ฉํด์ progress bar๋ฅผ ํ์ํด์ค๋ค.load_best_model_at_end: True๋ก ์ค์ ํ๋ฉด ํ๋ จ์ด ๋๋ฌ์ ๋ ๊ฐ์ฅ ์ฑ๋ฅ์ด ์ข์ ๋ชจ๋ธ์ ๋ก๋ํด์ค๋ค.metric_for_best_model: ์ด๋ โ์ฑ๋ฅ์ด ์ข๋คโ๋ ๊ฒ์ ๊ธฐ์ค์ ๋ฌด์์ผ๋ก ์ผ์์ง ์ง์ ํด์ค๋ค. ํ์ ํ compute_metrics ํจ์์์ ๋ฐํํ๋ metric ์ค ํ๋์ ์ด๋ฆ์ string์ผ๋ก ๋ฃ์ด์ฃผ๋ฉด ๋๋ค (e. g. โaccuracyโ)greater_is_better: ํด๋น metric์ด ๋์ ์๋ก ์ข์ ๊ฒ์ธ์ง, ๋ฎ์์๋ก ์ข์ ๊ฒ์ธ์ง๋ฅผ ์๋ ค์ค๋ค.TrainerTrainingArguments์์ ํ์ต์ ์ํฌ ๋ ์๋ ค์ค์ผ ํ ์ธ๋ถ์ฌํญ๋ค์ ๋ฃ์ด์ฃผ์๋ค๋ฉด, Trainer์์๋ ์ข ๋ ๊ธฐ๋ณธ์ ์ธ ๊ตต์งํ ์ ๋ณด๋ค์ ์๋ ค์ฃผ์ด์ผ ํ๋ค.
model: transformer ๋ชจ๋ธ ๊ฐ์ฒด๋ PyTorch์ nn.Module ๊ฐ์ฒด๋ฅผ ๋ฃ์ด์ฃผ๋ฉด ๋๋ค.model ๋์ ์, ์๋ก์ด ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ํ๋ ๋ง๋ค์ด ๋ฐํํ๋ ํจ์์ธ model_init์ ์ ๊ณตํด์ค ์๋ ์๋ค. args: ์์ ์๊ฐํ TrainingArguments ๊ฐ์ฒด์ด๋ค. data_collator: train/evaluation dataset์ ์๋ ์์๋ค์ list๋ฅผ ๋ฌถ์ด์ batch๋ก ๋ง๋ค์ด์ฃผ๊ธฐ ์ํด ์ฌ์ฉ๋๋ ํจ์์ด๋ค. tokenizer๋ฅผ ์ง์ ํด์ฃผ์ง ์์ ๊ฒฝ์ฐ default_data_collator()๊ฐ ์ฌ์ฉ๋๊ณ , ์ง์ ํด์ค ๊ฒฝ์ฐ๋ DataCollatorWithPadding์ ์ธ์คํด์ค๊ฐ ์ฌ์ฉ๋๋ค.train_dataset: ๊ฐ์ฅ ์ค์ํ๋ค๊ณ ํ ์ ์๋ ํ๋ จ์ฉ ๋ฐ์ดํฐ์
์ ์ง์ ํด์ค๋ค. transformer์ dataset.Dataset์ผ์๋, PyTorch์ Dataset์ผ ์๋ ์๋ค.eval_dataset: Evaluation์ฉ ๋ฐ์ดํฐ์
์ ์ง์ ํด์ค๋ค. ํ์์ train_dataset๊ณผ ๊ฐ๋ค.tokenizer: ๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌํ๊ธฐ ์ํ ํ ํฌ๋์ด์ ๋ฅผ ์ง์ ํด์ค๋ค. compute_metrics: Evaluation์์ metric๋ค์ ๊ณ์ฐํด์ฃผ๊ธฐ ์ํ ํจ์์ด๋ค. ํจ์์ ์
๋ ฅ๊ณผ ์ถ๋ ฅ์ ํน์ ํ์์ ๋ฐ๋ผ์ผ๋ง ํ๋๋ฐ, ์ด๋ ๋ค์์ ๋ง์ ์ค๋ช
ํ๊ฒ ๋ค.optimizers: ํ๋ จ์ ์ฌ์ฉํ optimizer์ LR scheduler๋ฅผ ์ง์ ํด์ค๋ค. torch.optim.Optimizer ๊ฐ์ฒด์ torch.optim.lr_scheduler.LambdaLR ๊ฐ์ฒด์ tuple์ ์๊ตฌํ๋ค. ์๋ฌด๊ฒ๋ ์
๋ ฅํ์ง ์์ผ๋ฉด AdamW๋ฅผ ์ฌ์ฉํ๋ค.์์์ ์ค๋ช
ํ TrainingArguments์ Trainer์ ๋งค๊ฐ๋ณ์๋ค ์ค์์๋ ์ซ์๋ ๋ฌธ์์ด์ด ์๋, ํจ์๋ ํน์ ํด๋์ค์ ๊ฐ์ฒด๋ฅผ ์๊ตฌํ๋ ๊ฒ๋ค์ด ์๋ค. ์ด ๋ ์ธ์๋ก ์ฃผ์ด์ง๋ ํจ์๋ ๋น์ฐํ ํน์ ํ ์
๋ ฅ๊ณผ ์ถ๋ ฅ ํ์์ ๋ฐ๋ผ์ผ๋ง ํ ๊ฒ์ด๊ณ , ๊ฐ์ฒด๋ ๋น์ฐํ ์ ํด์ ธ ์๋ ํน์ ํด๋์ค์ ๊ฐ์ฒด์ฌ์ผ๋ง ํ ๊ฒ์ด๋ค. ๊ทธ๋ฌ์ง ์์ผ๋ฉด ์๋ฌ๊ฐ ๋ฐ์ํ๊ฒ ๋๋ค. ์ค๋ช
ํ ๊ฒ ์ธ์๋ค ์ค์์๋ data_collator์ compute_metrics๊ฐ ์ด๋ฌํ ๊ฒฝ์ฐ์ ํด๋นํ๋๋ฐ, ๊ฐ๊ฐ ์ด๋ค ํ์์ ๋ฐ๋ผ์ผ ํ๋์ง ๊ฐ๋ตํ๊ฒ ์์๋ณด์.
data_collatordata_collator๋ ์์ ์ค๋ช
ํ๋ฏ์ด ๋ฐ์ดํฐ๋ฅผ batch๋ก ๋ฌถ์ด model์ ์ ๋ฌํ ์ ์๋ ํํ๋ก ๋ง๋ค์ด์ฃผ๋ฉฐ, DataCollator ํด๋์ค์ ์ธ์คํด์ค๊ฐ ๋๋๋ก ์ ํด์ ธ ์๋ค. DataCollator ํด๋์ค๋ ์ฌ๋ฌ ์์ ํด๋์ค๋ฅผ ๊ฐ์ง๋๋ฐ, ์ด๋ค task๋ฅผ ์ํํ๋๋์ ๋ฐ๋ผ ๋ค๋ฅธ ๊ฒ์ ์ฌ์ฉํ๋ฉด ๋๋ค.
DataCollatorWithPadding: ์
๋ ฅ๋ ์ํ์ค๋ฅผ ๊ธธ์ด๊ฐ ๋์ผํด์ง๋๋ก ํจ๋ฉํ์ฌ batch๋ฅผ ๋ง๋ ๋ค. ํ
์คํธ ๋ถ๋ฅ์ ๊ฐ์ ์์
์ ์ฌ์ฉํ๋ค.DataCollatorForSeq2Seq: sequence-to-sequence ์์
(e.g. ๋ฒ์ญ, ์์ฝ)์ ์ํ ํด๋์ค์ด๋ค. ์ด ๊ฒฝ์ฐ source์ target sequence๋ฅผ ๋ชจ๋ ํจ๋ฉํด์ค๋ค.DataCollatorForLanguageModeling: masked language modeling(MLM)๊ณผ ๊ฐ์ด ์ธ์ด ๋ชจ๋ธ๋ง(language modeling) task์ ์ฌ์ฉ๋๋ค.mlm_probability๋ก ํ ํฐ์ ๋ง์คํนํ ํ๋ฅ ์ ์ง์ ํด์ค ์ ์๋ค.DataCollatorForTokenClassification: ํ ํฐ ๋ถ๋ฅ ์์
(์: NER)์ ์ํ ๋ฐฐ์น๋ฅผ ๋ง๋ค์ด์ค๋ค.DefaultDataCollator: ํน๋ณํ ์ ์ฒ๋ฆฌ ์์ด ๋ฐ์ดํฐ๋ฅผ ๋ฐฐ์น๋ก ๋ฌถ๊ธฐ๋ง ํ๋ data collator๋ก, Trainer ํด๋์ค์ ๊ธฐ๋ณธ๊ฐ์ด๋ค.compute_metricsEvaluation ์์ ์ฌ์ฉํ metric์ ๊ณ์ฐํด์ฃผ๋ compute_metrics๋ EvalPrediction ๊ฐ์ฒด๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ์ dictionary๋ฅผ ์ถ๋ ฅํ๋ ํจ์์ด๋ค. ์ด๋ EvalPrediction์ ์ผ์ข
์ named tuple์ผ๋ก, predictions์ label_ids๋ผ๋ ๋ ๊ฐ์ ์์ฑ์ ํ์์ ์ผ๋ก ๊ฐ๋๋ค. ์ด๋ฆ์์ ์ ์ ์๋ฏ์ด, predictions๋ ๋ชจ๋ธ์ ์์ธก๊ฐ, label_ids๋ ๋ฐ์ดํฐ์
์ด ์ ๊ณตํ๋ ์ ๋ต์ ์๋ฏธํ๋ค. ์ด ๋์ ์ฌ๋ฌ ๊ฐ์ metric์ ์ฌ์ฉํด ๋น๊ตํ๋ ๊ฒ์ด ๋ฐ๋ก compute_metrics์ ์ญํ ์ด๋ผ๊ณ ํ ์ ์๋ค. ๊ณ์ฐ์ ์๋ฃํ๋ฉด metric์ ์ด๋ฆ์ key๋ก, ๊ทธ ๊ฐ์ value๋ก ํ๋ dictionary๋ฅผ ๋ฐํํด์ผ ํ๋ค.
๋ค์์ compute_metrics๋ฅผ ์์ฑํ ์์์ด๋ค.
import numpy as np
from datasets import load_metric
from transformers import TrainingArguments, Trainer
# metric๋ค์ ๊ฐ์ ธ์ค๊ธฐ
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")
def compute_metrics(eval_pred):
predictions, label_ids = eval_pred
# predictions, label_ids = eval_pred.predictions, eval_pred.label_ids
# ์ ๊ฐ์ด ์ ๊ทผํ ์๋ ์๋ค.
preds = predictions.argmax(axis=1)
accuracy = accuracy_metric.compute(predictions=preds, references=label_ids)
f1 = f1_metric.compute(predictions=preds, references=label_ids, average="weighted")
return {
"accuracy": accuracy["accuracy"],
"f1": f1["f1"],
}
Trainer API๋ transformers ๋ชจ๋ธ์ ํ๋ จ์ํค๋ ๊ฒ์ ์ต์ ํ๋์ด ์์ง๋ง, ์ฌ์ฉ์๊ฐ PyTorch๋ก ๊ตฌํํ ์ปค์คํ ๋ชจ๋ธ์ ํ๋ จ์ํฌ ๋๋ ์ฌ์ฉํ ์ ์๋ค. Trainer API documentation์์๋ ์ปค์คํ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์ฃผ์ํด์ผ ํ ์ ๋ค์ ์ธ๊ธํ๊ณ ์๋ค.
labels argument๊ฐ ์ฃผ์ด์ง ๊ฒฝ์ฐ, loss๋ฅผ ๊ณ์ฐํ์ฌ (๋ชจ๋ธ์ด tuple์ ๋ฆฌํดํ ๊ฒฝ์ฐ) tuple์ ์ฒซ ๋ฒ์งธ ์์๋ก ๋ฆฌํดํด์ผ ํ๋ค.label_names ๋งค๊ฐ๋ณ์๋ก label์ ์ด๋ฆ์ ๋ด์ list๋ฅผ ๋ฃ์ด์ฃผ์ด์ผ ํ๋ฉฐ ์ฌ๋ฌ ๊ฐ์ label ์ค ์ด๋ฆ์ด ๊ทธ๋ฅ label์ธ ๊ฒ์ ์์ผ๋ฉด ์๋๋ค. ์ด์ ํ๊น
ํ์ด์ค Trainer ์ฌ์ฉ๋ฒ์ ๋ฐฐ์ ์ผ๋, ์ค์ LLM์ ํ๋ จํด๋ณด์. GPT-2๋ฅผ IMDb ์ํ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ์
์์ fine-tuningํ์ฌ, ์์ฑ๋ ์ํ๋ฆฌ๋ทฐ๊ฐ ์ํ๋ฅผ ์ข๊ฒ ํ๊ฐํ๋์ง, ๋์๊ฒ ํ๊ฐํ๋์ง ๋ถ๋ฅํ๋ ๊ฐ๋จํ ์์
์ ์ํํ๋๋ก ํด๋ณธ๋ค.
๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ค์น
%%capture
!pip install -U datasets transformers accelerate
ํ์ํ ํจํค์ง๋ค์ ์ค์นํด์ค๋ค. Colab ๊ธฐ์ค์ผ๋ก ์ ํจํค์ง๋ค์ ์ด๋ฏธ ์ค์น๋์ด ์์ง๋ง, ํ์ฌ(2024๋
5์ 15์ผ ๊ธฐ์ค) ๋ฒ์ ๋ฌธ์ ์ธ์ง ์ด ์์
์ ํด์ฃผ์ง ์์ผ๋ฉด ํ๋ จ์ด ๋์ง ์๋๋ค. ์ด์ธ์๋ ์ฌ์ฉํ๊ฒฝ์ ๋ฐ๋ผ ์ฝ๋๋ฅผ ์คํํ๋ฉด์ ํจํค์ง๊ฐ ์ค์น๋์ด ์์ง ์๋ค๊ณ ๋์ฌ ๊ฒฝ์ฐ pip์ผ๋ก ์ค์นํด์ฃผ๋ฉด ๋๋ค.
๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ถ๋ฌ์ค๊ธฐ
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_ckpt = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
ํ๊น ํ์ด์ค ํ๋ธ์์ GPT-2๋ฅผ ์ฐพ์์ ์ํฌํธํด์ฃผ์๋ค. ๋งํฌ์์ Use in Transformers ๋ฒํผ์ ๋๋ฌ ๊ฐํธํ๊ฒ ๋ถ๋ฌ์ฌ ์ ์๋ค.
๋ฐ์ดํฐ์ ๋ถ๋ฌ์ค๊ธฐ
from datasets import load_dataset
dataset = load_dataset("stanfordnlp/imdb")
ds_train = dataset['train'].shuffle().select(range(10000))
ds_test = dataset['test'].shuffle().select(range(2500))
IMDb ๋ฐ์ดํฐ์ ์ ๊ฐ์ ธ์จ๋ค. ๋ฐ์ดํฐ์ ๋ํ ํ๊น ํ์ด์ค ํ๋ธ์์ ์ฝ๊ฒ ๋ถ๋ฌ์ฌ ์ ์๋๋ก ์ ๊ณตํ๊ณ ์๋ค. (๋งํฌ) ์ค์ ๋ฐ์ดํฐ์ ์ train๊ณผ test set์ด ๊ฐ๊ฐ 25000๊ฐ์ ๋ฐ์ดํฐ๋ก ์ด๋ฃจ์ด์ ธ ์์ง๋ง, ๋น ๋ฅธ ํ์ต์ ์ํด์ ๊ฐ๊ฐ 10000๊ฐ์ 2500๊ฐ๋ง ์ฌ์ฉํ๊ฒ ๋ค.
Metric ์ ์ํ๊ธฐ
import numpy as np
from datasets import load_metric
from transformers import TrainingArguments, Trainer
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")
def compute_metrics(eval_pred):
predictions, label_ids = eval_pred.predictions, eval_pred.label_ids
predictions = predictions.argmax(axis=1)
accuracy = accuracy_metric.compute(predictions=predictions, references=label_ids)
f1 = f1_metric.compute(predictions=predictions, references=label_ids, average="weighted")
return {
"accuracy": accuracy["accuracy"],
"f1": f1["f1"],
}
Evaluation ์ ์ฌ์ฉํ metric๋ค์ ์ง์ ํด์ฃผ๊ธฐ ์ํด์ compute_metrics ํจ์๋ฅผ ์ ์ํด์ค๋ค. ์ฌ๊ธฐ์์๋ accuracy์ F1 score๋ฅผ ์ฌ์ฉํ๋ค. ํจ์๊ฐ ํ์์ ์ ๋ง๋์ง ์ฃผ์ํด์ผ ํ๋ค.
ํจ๋ฉ ํ ํฐ ์ง์ ํ๊ธฐ, Data Collator ์ ์ํ๊ธฐ
from transformers import DataCollatorWithPadding
model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, max_length=256)
๋ฐฐ์น๋ฅผ ๋ง๋ค์ด์ ๊ธธ์ด๊ฐ ๋ค๋ฅธ ์ฌ๋ฌ ๊ฐ์ ์ํ์ค๋ฅผ ํ๊บผ๋ฒ์ ์ฒ๋ฆฌํ๋ ค๋ฉด ํจ๋ฉ์ด ์ด๋ฃจ์ด์ ธ์ผ ํ๋ค. ์ด๋ฅผ ์ํํ๊ธฐ ์ํ data_collator๋ฅผ ์ ์ํ๋ค. ๋ํ, model๊ณผ tokenizer์๊ฒ ํจ๋ฉ ํ ํฐ์ด ๋ฌด์์ธ์ง๋ฅผ ์๋ ค์ค์ผ ํ๋๋ฐ, ์ผ๋ฐ์ ์ผ๋ก ์์ ๊ฐ์ด EOS(end of sequence) ํ ํฐ๊ณผ ๋์ผํ๊ฒ ์ง์ ํด์ค๋ค.
Trainer ์ ์ํ๊ธฐ
์ด์ ์ ์ํ ๋ณ์๋ค์ ๋ชจ๋ ๋ชจ์ Trainer๋ฅผ ์ ์ํด์ค ์ฐจ๋ก์ด๋ค.
from transformers import Trainer, TrainingArguments
training_arguments = TrainingArguments(
output_dir='./results',
evaluation_strategy="epoch",
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
learning_rate=3e-5,
logging_strategy="epoch",
load_best_model_at_end=True,
save_strategy="epoch",
metric_for_best_model="accuracy",
)
trainer = Trainer(
model=model,
train_dataset=ds_train,
data_collator=data_collator,
eval_dataset=ds_test,
args=training_arguments,
compute_metrics=compute_metrics,
)
๋ฐฐ์ ๋๋๋ก TrainingArguments์ Trainer๋ฅผ ์ฐจ๋ก๋๋ก ์ ์ํด์ฃผ๊ณ , ํ์ํ ๋งค๊ฐ๋ณ์๋ค์ ํ๋์ฉ ๋ฃ์ด์ฃผ์. ๊ผญ ์ด ๊ธ์ ์๋๋๋ก ํ ํ์ ์์ด, ์ธ์๋ค์ ํ๋์ฉ ๋ฐ๊ฟ๋ณด๊ฑฐ๋ ๋ค๋ฅธ ๋งค๊ฐ๋ณ์๋ค์ ๋ฃ์ด๋ณด๋ ์์ผ๋ก ์ฝ๋๋ฅผ ๋ฐ๊ฟ๋ณด๋ฉด ์ดํด์ ๋์์ด ๋ ๊ฒ์ด๋ค.
ํ๋ จํ๊ธฐ
trainer.train()

์์ ๊ฐ์ด ํ๋ จ์ด ์ ์งํ๋๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
โHugging Faceโ, Wikipedia
Trainer API documentation
# ๋ฐ์ดํฐ ๊ณผํ์๋ค์ด ์ซ์ 42๋ฅผ ์ข์ํ๋ ์ด์
Github transformers EvalPrediction ์์ค์ฝ๋
transformers.utils.ModelOutput documentation
lr_scheduler_type ๋ํดํธ๋ linear ์ ๋๋ค~!