[PyTorch x Kaggle] PyTorch Custom Dataset ์•Œ์•„๋ณด๊ธฐ ๐Ÿฅ

KwanHongยท2020๋…„ 12์›” 1์ผ
1

PyTorch x Kaggle

๋ชฉ๋ก ๋ณด๊ธฐ
1/1
post-thumbnail

๐Ÿ‘จ๐Ÿปโ€๐Ÿ’ป Introduction

PyTorch ํ”„๋ ˆ์ž„์›Œํฌ์— ์ต์ˆ™ํ•ด์ง€๊ธฐ ์œ„ํ•ด์„œ, kaggle dataset์„ ์ด์šฉํ•œ pytorch ๋…ธํŠธ๋ถ์„ ๋”ฐ๋ผํ•˜๊ณ  ์ฃผ์š” ํ† ํ”ฝ๋“ค์„ ํŒŒ๊ณ ๋“ค๊ธฐ๋กœ ํ•˜์˜€๋‹ค.

  • ์„ ์ • Topic
    • PyTorch custom dataset basics
    • ์˜ˆ์ œ dataset์ด ์•„๋‹Œ ์‹ค์ œ ํ’€๊ณ ์ž ํ•˜๋Š” ๋ฌธ์ œ์˜ custom dataset์ด ํ•„์š”

์•„๋ž˜๋Š” ํ•ด๋‹น ๋…ธํŠธ๋ถ์„ ์ €์žฅํ•œ Repository์ด๋‹ค.

๐Ÿงช Toxic comment dataset

Toxic comment dataset ๊ฐœ์š”

  • ์œ„ํ‚คํ”ผ๋””์•„ comment ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ
  • comment์˜ toxicity์— ๋”ฐ๋ผ ์‚ฌ๋žŒ์ด 6๊ฐ€์ง€ ํƒ€์ž…์˜ ํด๋ž˜์Šค๋กœ labeling
  • ๊ฐ toxicity ํƒ€์ž… ๋ณ„ ํ™•๋ฅ ์„ ์˜ˆ์ธกํ•˜๋Š” ๋ชจ๋ธ์„ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด ๋ชฉํ‘œ

Dataset ์˜ˆ์‹œ

  • comment_text๊ณผ ๋Œ€์‘ํ•˜๋Š” 6๊ฐœ์˜ toxicity label์ด ํ•œ ์Œ

โš™๏ธ Custom dataset skeleton

Toxic ๋ฐ์ดํ„ฐ์…‹์„ ๋‹ค๋ฃฐ custom Dataset ํด๋ž˜์Šค์ธ ToxicDataset์„ ์ƒ์„ฑํ•œ๋‹ค.

torch.utils.data.Dataset์„ ์ƒ์†๋ฐ›๊ณ , method๋ฅผ overrideํ•œ๋‹ค.

  • Pycharm์˜ Override Methods... ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•˜์—ฌ Dataset์—์„œ overrideํ•  methods์˜ hint ํ™•์ธ

PyTorch๋Š” dataset์„ ๋‘ ๊ฐ€์ง€ ํƒ€์ž…์œผ๋กœ ์ง€์›ํ•œ๋‹ค. (https://pytorch.org/docs/stable/data.html)

  • map-style datasets
  • iterable-stype datasets

ํ•ด๋‹น Dataset abstract class๋Š” map-style dataset์„ ๊ตฌํ˜„ํ•  ๊ฒฝ์šฐ์— ์ƒ์†๋ฐ›๋Š”๋‹ค.
์ด ํƒ€์ž…์˜ dataset์€ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ์„ key๋กœ ์ ‘๊ทผ ๊ฐ€๋Šฅํ•œ map ํ˜•ํƒœ๋ผ๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋œ๋‹ค.

  • dataset[idx]์™€ ๊ฐ™์ด ์ ‘๊ทผํ•˜์—ฌ idx ๋ฒˆ์งธ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ๋ฅผ ์ฝ์„ ์ˆ˜ ์žˆ๋‹ค.

์ด ํƒ€์ž…์˜ dataset์„ ๊ตฌ์„ฑํ•˜๊ธฐ ์œ„ํ•ด์„œ ๊ตฌํ˜„์ด ํ•„์š”ํ•œ method๋Š” ์•„๋ž˜์™€ ๊ฐ™๋‹ค.

  • __init__: ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ ์ „์ฒ˜๋ฆฌ, csv ์ฝ์–ด์˜ค๊ธฐ, transform ์ ์šฉ๊ณผ ๊ฐ™์€ initial logic ๊ตฌํ˜„
  • __getitem__: ๋ฐ์ดํ„ฐ์…‹์—์„œ ๋‹จ์ผ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ค๋Š” ํ•จ์ˆ˜. DataLoader์—์„œ indice ๋˜๋Š” keys๋กœ ๋‹จ์ผ ๋ฐ์ดํ„ฐ์— ์ ‘๊ทผํ•˜๊ฒŒ ๋œ๋‹ค.
    ๋‹ค๋ฃจ๊ณ ์ž
  • __len__: ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ์˜ ์ด ๊ฐฏ์ˆ˜

โš’ Preparing custom dataset

Toxic comment dataset์˜ cuustom dataset ๊ตฌํ˜„์„ ์œ„ํ•œ ์ฃผ์š”์‚ฌํ•ญ์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

  • Dataset์€ ๋ชจ๋ธ์˜ ์ž…๋ ฅ์œผ๋กœ comment_text์™€ ์ด์— ๋Œ€์‘ํ•˜๋Š” label์„ __getitem__์—์„œ ๋ฐ˜ํ™˜
    - Dataset instance ์ƒ์„ฑ ์‹œ, tensor๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ instance ๋‚ด๋ถ€ ์ž๋ฃŒ๊ตฌ์กฐ์— ์œ ์ง€
  • BERT ๋ชจ๋ธ์— ์ž…๋ ฅํ•˜๊ธฐ ์œ„ํ•ด ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ๋ณ€ํ™˜ํ•  tokenizer ํ•„์š”

๐Ÿงฐ __init__ method

csv ํ˜•ํƒœ์˜ ๋ฐ์ดํ„ฐ ์…‹์„ ์•ž์„œ์„œ, pandas.DataFrame์œผ๋กœ load ํ•˜์˜€๋‹ค.

๐Ÿ”‘ __getitem__ method

์ดํ›„ DataLoader์—์„œ๋Š” Dataset์˜ __getitem__ ๋ฉ”์†Œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ๋ฐ์ดํ„ฐ๋ฅผ ์ฝ์–ด๋“ค์ธ๋‹ค.

์œ„์—์„œ ๋ณธ Toxic comment dataset์—์„œ 0, 1 label ๋‘˜ ๋‹ค ์žˆ๋Š” 7๋ฒˆ์งธ ๋ฐ์ดํ„ฐ(index: 6)๋ฅผ ์˜ˆ์‹œ๋กœ ๋ถˆ๋Ÿฌ์™€ ์ถœ๋ ฅํ–ˆ๋‹ค.

๐Ÿ“ __len__ method

DataLoader์—์„œ __len__ ๋ฉ”์†Œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ, ๋ฐ์ดํ„ฐ์…‹์˜ ์ „์ฒด ํฌ๊ธฐ๋ฅผ ์•Œ๊ณ  ์ด๋ฅผ ์ด์šฉํ•˜์—ฌ batch๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค.

References

profile
๋ณธ์งˆ์— ์ง‘์ค‘ํ•˜๋ ค๊ณ  ๋…ธ๋ ฅํ•ฉ๋‹ˆ๋‹ค. ๐Ÿ”จ

0๊ฐœ์˜ ๋Œ“๊ธ€