โป ์ด ๊ธ์ ์๋ฌธ์ ์ด ๊ณณ์์ ํ์ธํ ์ ์์ต๋๋ค.
โป ๋ชจ๋ ๊ธ์ ๋ด์ฉ์ ํฌํจํ์ง ์์ผ๋ฉฐ ์๋กญ๊ฒ ๊ตฌ์ฑํ ๋ด์ฉ๋ ํฌํจ๋์ด ์์ต๋๋ค.
pipeline() ํจ์๋ฅผ ์ด์ฉํด ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ์ ํํ ๋ชจ๋ฅด๋๋ผ๋ pretrained ๋ชจ๋ธ์ ์ด์ฉํด ์์ฐ์ด์ฒ๋ฆฌ task๋ฅผ ์ํํ ์ ์์ต๋๋ค.
Model Hub์์ pretrained model๋ค์ ํ์ธํด๋ณผ ์ ์์ต๋๋ค. ๊ฐ ๋ชจ๋ธ๋ณ๋ก ์ํํ ์ ์๋ task๊ฐ ๋ชจ๋ ๋ค๋ฅด๋ฏ๋ก task์ ์ ํฉํ ๋ชจ๋ธ์ ์ฐพ์์ผํฉ๋๋ค.
task์ ์ ํฉํ model์ ์ฐพ์๋ค๋ฉด AutoModel, AutoTokenizer ํด๋์ค๋ฅผ ์ด์ฉํ์ฌ model๊ณผ model์ ์ฌ์ฉ๋๋ tokenizer๋ฅผ ๊ฐ๋จํ๊ฒ ๋ค์ด๋ก๋ํ ์ ์์ต๋๋ค.
!pip install transformers
ํ๊ตญ์ด fill-mask task๋ฅผ ์ํํ๊ธฐ์ํด BERT pretrained ๋ชจ๋ธ ์ค์์ bert-base-multilingual-cased๋ฅผ ๋ถ๋ฌ์ต๋๋ค.
from_pretrained()์ model ์ด๋ฆ์ ๋ฃ์ผ๋ฉด ์์ฝ๊ฒ pretrained model, tokenizer๋ฅผ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค.
from transformers import AutoModelForMaskedLM, AutoTokenizer
MODEL_NAME = 'bert-base-multilingual-cased'
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
๋จผ์ tokenizer๊ฐ ์ ์์ ์ผ๋ก ๋์ํ๋์ง ํ์ธํฉ๋๋ค.
fill-mask task๋ฅผ ์ํํ๋ ค๋ฉด text๋ด์ [MASK] special token์ด ํฌํจ๋์ด ์์ด์ผํฉ๋๋ค.
text = "์ด์์ ์ [MASK] ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค."
tokenizer.tokenize(text)
['์ด', '##์', '##์ ', '##์', '[MASK]', '์ค', '##๊ธฐ์', '๋ฌด', '##์ ', '##์ด๋ค', '.']
BERT๋ WordPiece ๋ฐฉ์์ tokenization์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ##์ด๋ผ๋ ํน๋ณํ prefix๊ฐ ๋ถ์ด์๋ token๋ค์ ํ์ธํ ์ ์์ต๋๋ค.
pipeline()์ ์ด์ฉํด ํ๊ตญ์ด fill-mask task๋ฅผ ์ํํ๊ธฐ์ํ ํจ์๋ฅผ ๋ง๋ญ๋๋ค.
from transformers import pipeline
kor_mask_fill = pipeline(task='fill-mask', model=model, tokenizer=tokenizer)
kor_mask_fill ํจ์๋ฅผ ์ด์ฉํ์ฌ fill-mask task๋ฅผ ์ํํฉ๋๋ค.
text = "์ด์์ ์ [MASK] ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค."
kor_mask_fill("์ด์์ ์ [MASK] ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.")
[{'score': 0.874712347984314,
'token': 59906,
'token_str': '์กฐ์ ',
'sequence': '์ด์์ ์ ์กฐ์ ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.'},
{'score': 0.0643644854426384,
'token': 9751,
'token_str': '์ฒญ',
'sequence': '์ด์์ ์ ์ฒญ ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.'},
{'score': 0.010954903438687325,
'token': 9665,
'token_str': '์ ',
'sequence': '์ด์์ ์ ์ ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.'},
{'score': 0.004647187888622284,
'token': 22200,
'token_str': '##์ข
',
'sequence': '์ด์์ ์์ข
์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.'},
{'score': 0.0036106701008975506,
'token': 12310,
'token_str': '##๊ธฐ',
'sequence': '์ด์์ ์๊ธฐ ์ค๊ธฐ์ ๋ฌด์ ์ด๋ค.'}]
[MASK] ์๋ฆฌ์ ๋ค์ด๊ฐ token๋ค์ ๋ฆฌ์คํธ ํํ๋ก ๋ฐํํฉ๋๋ค.