CatBoost

์•ˆ์„ฑ์ธยท2022๋…„ 4์›” 7์ผ
1

📖 In this post, we look at CatBoost, a boosting algorithm built on GBM (Gradient Boosting Machines).
CatBoost was introduced in a 2017 paper and is still widely used in industry. Let's go through its key concepts.


[What is CatBoost?]

  • GBM์˜ ์น˜๋ช…์ ์ธ ๋ฌธ์ œ์  ์ค‘ ํ•˜๋‚˜๋กœ ๊ณผ์ ํ•ฉ ๋ฌธ์ œ๊ฐ€ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ ํ•ฉ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋ฉด์„œ ๋™์‹œ์— ๊ธฐ์กด GBM๊ณ„์—ด์˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜์ธ XGBoost, LightGBM ์•Œ๊ณ ๋ฆฌ์ฆ˜๋ณด๋‹ค ํ•™์Šต ์†๋„๋ฅผ ๊ฐœ์„ ํ•˜๋Š” ์žฅ์ ์„ ์•ž์„ธ์›Œ ๊ฐœ๋ฐœ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

    • GBM์€ ์ตœ์ดˆ์˜ ๋ฐ์ดํ„ฐ๋ฅผ ํ•™์Šตํ•  ๋•Œ๋งŒ ์˜ˆ์ธก๊ฐ’์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•˜๊ณ  ๊ทธ ์ดํ›„์˜ ๋ฐ์ดํ„ฐ๋ฅผ ํ•™์Šตํ•  ๋•Œ๋Š” ์˜ˆ์ธก๊ฐ’์„ ํ™œ์šฉํ•ด์„œ ๊ณ„์‚ฐํ•œ ์ž”์ฐจ์—๋งŒ ํฌ์ปค์Šค๋ฅผ ๋งž์ถ”์–ด์„œ ํ•™์Šตํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.
    • ํ•˜์ง€๋งŒ ์ด๋ ‡๊ฒŒ ์ž”์ฐจ์—๋งŒ ํฌ์ปค์Šค๋ฅผ ๋งž์ถ”์–ด ํ•™์Šตํ•˜๋Š” ๊ฒƒ์ด ๋งค์šฐ ์ด์ƒ์ ์ด๋ผ๊ณ  ๋ณด์ผ์ง€๋„ ๋ชจ๋ฅด๊ฒ ์ง€๋งŒ ๋ชจ๋ธ์ด ๋ณธ์  ์—†๋Š” ๋ฐ์ดํ„ฐ์—๋Š” ์˜ˆ์ธก์„ ์ž˜ ํ•˜์ง€ ๋ชปํ•˜๋Š” ๊ณผ์ ํ•ฉ(Overfitting) ๋ฌธ์ œ๋ฅผ ์œ ๋ฐœํ•  ๊ฐ€๋Šฅ์„ฑ์ด ๋งค์šฐ ๋†’๋‹ค๋Š” ๊ฒƒ์ด ์น˜๋ช…์ ์ธ ๋‹จ์ ์ž…๋‹ˆ๋‹ค.
  • It is designed to handle categorical features especially well.

  • ๊ธฐ์กด์˜ ๊ทธ๋ž˜๋””์–ธํŠธ ๋ถ€์ŠคํŒ… ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์กฐ์ž‘ํ•˜์—ฌ ํƒ€๊ฒŸ ๋ˆ„์ˆ˜(target leakage)๋ฅผ ๊ฐœ์„ ํ•ฉ๋‹ˆ๋‹ค.

    • Target leakage is the error of including information in the dataset that would not be available at prediction time.
    • In other words, the model should predict the dependent variable y from the independent variables x alone, but information about y has leaked into x.
    • Standard gradient boosting fits each tree to the gradients of the loss function. These are the partial derivatives of the loss with respect to the current predictions, and they are computed from the target values.
      This looks like a good idea, but the same target values are used both to compute the gradients and to fit the trees. That is target leakage, and it causes two problems:
      - the distribution of the model's outputs differs between the training and test sets;
      - as a result, the model overfits.
    • Put simply, each new tree reuses the same data that the previous trees were built on, so the model overfits easily.
  • It also aims to fix a weakness of XGBoost and LightGBM: their performance is sensitive to the choice of hyperparameters.
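Here is a minimal base-R sketch of the residual-fitting loop and the data reuse described above. It uses plain gradient boosting with squared-error loss, where the negative gradient is simply the residual y - F(x).

The helper names (fit_stump, predict_stump, gbm_fit) are hypothetical, and depth-1 trees (stumps) stand in for real trees. This is not CatBoost's or any library's implementation. Every round computes residuals on the same rows whose targets the earlier trees were already fitted to.

```r
# Minimal sketch of plain gradient boosting with squared-error loss
# (hypothetical helpers; not CatBoost's or any library's implementation).

# Fit a depth-1 tree (stump) to residuals r by trying every midpoint of x.
fit_stump <- function(x, r) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  best <- list(sse = Inf)
  for (thr in cuts) {
    left <- x <= thr
    sse <- sum((r[left] - mean(r[left]))^2) + sum((r[!left] - mean(r[!left]))^2)
    if (sse < best$sse) {
      best <- list(sse = sse, thr = thr, left = mean(r[left]), right = mean(r[!left]))
    }
  }
  best
}

predict_stump <- function(s, x) ifelse(x <= s$thr, s$left, s$right)

gbm_fit <- function(x, y, n_trees = 50, lr = 0.1) {
  f <- rep(mean(y), length(y))        # initial prediction: the mean of y
  for (m in seq_len(n_trees)) {
    r <- y - f                        # residuals = negative gradient of (1/2)(y - F)^2
    s <- fit_stump(x, r)              # fitted on the SAME rows every round
    f <- f + lr * predict_stump(s, x) # add a shrunken copy of the new tree
  }
  f                                   # final in-sample predictions
}

set.seed(1)
x <- runif(100)
y <- sin(2 * pi * x) + rnorm(100, sd = 0.1)
train_pred <- gbm_fit(x, y)
c(start_mse = mean((y - mean(y))^2), end_mse = mean((y - train_pred)^2))
```

The training error keeps falling as trees are added, because each tree is fitted to residuals of rows the model has already seen. A low training error alone therefore says little about performance on unseen data.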


[CatBoost Features]

1. Level-wise Tree

  • The three libraries grow trees differently:
    • LightGBM grows trees leaf-wise. It always splits the leaf with the largest loss reduction, so trees tend to grow deep first, somewhat like DFS (depth-first search).
    • XGBoost grows trees level-wise, filling one depth level at a time, like BFS (breadth-first search).
    • CatBoost is also level-wise, but it builds symmetric (oblivious) trees: every node at the same depth uses the same feature and the same split condition. This structure reduces prediction time.
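A symmetric tree is cheap to evaluate because the leaf index can be computed from one comparison per level.

Below is a minimal base-R sketch for a depth-2 tree. The features (age, income), thresholds and leaf values are made up for illustration, and oblivious_predict is a hypothetical helper, not part of the catboost package.

```r
# Sketch of prediction with a symmetric (oblivious) tree of depth d.
# Every level uses ONE (feature, threshold) pair for all of its nodes,
# so the leaf index is a d-bit number built from d comparisons.
oblivious_predict <- function(X, splits, leaf_values) {
  idx <- integer(nrow(X))
  for (level in seq_len(nrow(splits))) {
    bit <- as.integer(X[[splits$feature[level]]] > splits$threshold[level])
    idx <- idx * 2L + bit          # append this level's bit to the leaf index
  }
  leaf_values[idx + 1L]            # R vectors are 1-based
}

# Hypothetical depth-2 tree: level 1 splits on age, level 2 on income.
splits <- data.frame(feature = c("age", "income"), threshold = c(30, 5000),
                     stringsAsFactors = FALSE)
leaf_values <- c(0.1, 0.4, 0.6, 0.9)   # 2^2 leaves
X <- data.frame(age = c(25, 25, 40, 40), income = c(3000, 8000, 3000, 8000))
oblivious_predict(X, splits, leaf_values)
# 0.1 0.4 0.6 0.9
```

Because every row goes through the same comparisons in the same order, prediction needs no branching on tree structure. This is one reason symmetric trees are fast at inference time.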

2. Ordered Boosting

  • CatBoost follows the same overall boosting process as other methods, with one important difference.

  • ๊ธฐ์กด์˜ ๋ถ€์ŠคํŒ… ๋ชจ๋ธ์ด ์ผ๊ด„์ ์œผ๋กœ ๋ชจ๋“  ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋ฅผ ๋Œ€์ƒ์œผ๋กœ ์ž”์ฐจ๊ณ„์‚ฐ์„ ํ–ˆ๋‹ค๋ฉด, CatBoost๋Š” ์ผ๋ถ€ ๋ฐ์ดํ„ฐ๋งŒ์„ ๊ฐ€์ง€๊ณ  ์ž”์ฐจ๊ณ„์‚ฐ์„ ํ•œ ๋’ค, ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด ๋‚˜๋จธ์ง€ ๋ฐ์ดํ„ฐ์˜ ์ž”์ฐจ๋Š” ์ด ๋ชจ๋ธ๋กœ ์˜ˆ์ธกํ•œ ๊ฐ’์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

Example

  • ๋ฐ์ดํ„ฐ ์„ธํŠธ์— 10๊ฐœ์˜ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๊ณ  ์•„๋ž˜์™€ ๊ฐ™์ด ์‹œ๊ฐ„ ์ˆœ์„œ๊ฐ€ ์ง€์ •๋˜์–ด ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

  • ๋ฐ์ดํ„ฐ์— ์‹œ๊ฐ„์ด ์—†์œผ๋ฉด CatBoost๋Š” ๊ฐ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ์— ๋Œ€ํ•ด ์ธ๊ณต ์‹œ๊ฐ„์„ ๋ฌด์ž‘์œ„๋กœ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.


1. First compute the residual of x1 only and build a model from it. Then use this model to compute the residual of x2.
2. Build a model from the residuals of x1 and x2, and use it to compute the residuals of x3 and x4.
3. Build a model from x1, x2, x3 and x4, and use it to compute the residuals of x5, x6, x7 and x8.
4. Repeat the process.

  • ์ด ๋ฐฉ๋ฒ•์€ target์ด ์•„๋‹Œ ๊ด€์ธก๋œ ๊ธฐ๋ก์—๋งŒ ์˜์กดํ•˜๊ธฐ์— target leakage๋ฅผ ๋ง‰์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

3. Random Permutation

Using the residual of the second data point in the order, r(x5, y5), the model is updated to M2, which includes that point. M2 is then used to predict the class value of the third data point (x9) and compute its residual.

  • If the data is never shuffled during Ordered Boosting, the residual-predicting models may be built in the same order every time. This order was chosen arbitrarily in the first place, so it should be reshuffled too.

  • CatBoost takes this into account and shuffles the data before sampling it. It can also sample only part of the data instead of all of it. All of these techniques diversify the trees in order to prevent overfitting.

  • Also, as in K-fold cross-validation, the full dataset can be randomly split into N folds, with Ordered Boosting applied to the data in each fold.
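Below is a rough sketch of how a fresh random permutation per tree changes which earlier samples each residual is computed from. It also shows optional row subsampling.

As before, the following are assumptions:

- ordered_residuals_perm is a hypothetical helper;
- the prefix mean stands in for a real model;
- the choice of 3 trees and 7 sampled rows is arbitrary.

CatBoost itself works with a small set of random permutations generated up front, which this sketch does not reproduce exactly.

```r
# Sketch: a fresh random order (permutation) per tree, plus optional row
# subsampling. Hypothetical helper; the prefix mean stands in for a real model.
ordered_residuals_perm <- function(y, perm, init = 0) {
  y_ord <- y[perm]                 # targets in the artificial time order
  n <- length(y_ord)
  res_ord <- numeric(n)
  res_ord[1] <- y_ord[1] - init
  prefix <- 1
  while (prefix < n) {
    upto <- min(2 * prefix, n)
    res_ord[(prefix + 1):upto] <- y_ord[(prefix + 1):upto] - mean(y_ord[1:prefix])
    prefix <- upto
  }
  res <- numeric(n)
  res[perm] <- res_ord             # map back to the original row order
  res
}

set.seed(42)
y <- c(3, 5, 1, 7, 2, 4, 6, 8, 10, 0)    # made-up targets
for (tree in 1:3) {
  perm <- sample(length(y))              # new artificial order for this tree
  rows <- sort(sample(length(y), 7))     # rows used to build this tree (7 of 10)
  r <- ordered_residuals_perm(y, perm)
  cat("tree", tree, "| order:", perm, "| residuals of sampled rows:",
      round(r[rows], 2), "\n")
}
```

Each tree sees each sample's residual computed against a different set of "earlier" samples. This is the diversity the text refers to.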


4. Ordered Target Encoding

  • CatBoost uses a technique known as Target Encoding, Mean Encoding or Response Encoding (three names for the same thing).

Example

  • Suppose we want to predict class_label from time and feature1. The rows where feature1 is cloudy fall on Tues, Wed, Fri and Sat, with class_label values 15, 14, 20 and 25. Then cloudy can be encoded as:
cloudy = (15 + 14 + 20 + 25) / 4 = 18.5
  • That is, cloudy is replaced by the mean class_label of the rows where feature1 is cloudy. This is why the technique is also called mean encoding.

  • ๊ทธ๋Ÿฐ๋ฐ ์œ„๋Š” ์šฐ๋ฆฌ๊ฐ€ ์˜ˆ์ธกํ•ด์•ผํ•˜๋Š” ๊ฐ’์ด ํ›ˆ๋ จ ์…‹ ํ”ผ์ฒ˜์— ๋“ค์–ด๊ฐ€๋ฒ„๋ฆฌ๋Š” ๋ฌธ์ œ, ์ฆ‰ target leakage ๋ฌธ์ œ๋ฅผ ์ผ์œผํ‚ต๋‹ˆ๋‹ค. ์ด๋Š” ์˜ค๋ฒ„ํ”ผํŒ…์„ ์ผ์œผํ‚ค๋Š” ์ฃผ ์›์ธ์ด์ž, Mean encoding ๋ฐฉ๋ฒ• ์ž์ฒด์˜ ๋ฌธ์ œ์ด๊ธฐ๋„ ํ•ฉ๋‹ˆ๋‹ค.

  • CatBoost's solution is to encode each row using only the target values of the rows that came before it:

- On Friday: cloudy = (15 + 14) / 2 = 14.5, i.e., (Tues + Wed) / 2
- On Saturday: cloudy = (15 + 14 + 20) / 3 ≈ 16.3, i.e., (Tues + Wed + Fri) / 3

What happens to Tues under Ordered Target Encoding? No earlier data exists, so there is nothing to average (naively, you would fill in 0). Instead, CatBoost uses a prior, in the spirit of Laplace smoothing. In the form given in the CatBoost paper, a row is encoded as (sum of earlier targets in its category + a·P) / (number of earlier rows in its category + a), where P is a prior value and a > 0 is its weight.
  • In other words, a row's own target value is never used, only those of earlier rows, so no target leakage occurs.

  • This is a clever way to convert categorical variables to numbers. It prevents overfitting, and it adds diversity by giving rows of the same category different numeric values.
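Ordered target encoding with a prior fits in a few lines of base R. The update follows the form (sum of earlier targets + a·P) / (count of earlier rows + a). The sketch assumes the following:

- the rows are already in time (or permutation) order;
- ordered_target_encoding is a hypothetical helper;
- prior = 10 is an arbitrary illustrative value.

With a = 0, the sketch reproduces the plain running averages from the cloudy example, including the undefined value for Tues.

```r
# Sketch of ordered target encoding with a prior (hypothetical helper).
# Rows must already be in time (or permutation) order. Each row is encoded
# as (sum of earlier targets in its category + a * prior) / (earlier count + a),
# so a row's own target is never used for its own encoding.
ordered_target_encoding <- function(category, target, prior, a = 1) {
  category <- as.character(category)
  enc <- numeric(length(category))
  sums <- numeric(0)     # running target sum per category (named vector)
  counts <- numeric(0)   # running row count per category
  for (i in seq_along(category)) {
    k <- category[i]
    s <- if (k %in% names(sums)) sums[[k]] else 0
    n <- if (k %in% names(counts)) counts[[k]] else 0
    enc[i] <- (s + a * prior) / (n + a)   # uses rows before i only
    sums[k] <- s + target[i]              # update AFTER encoding row i
    counts[k] <- n + 1
  }
  enc
}

feature1    <- c("cloudy", "cloudy", "cloudy", "cloudy")
class_label <- c(15, 14, 20, 25)          # Tues, Wed, Fri, Sat

# a = 0: plain running average; Tues is NaN (0/0) because nothing came before it
ordered_target_encoding(feature1, class_label, prior = 0, a = 0)
# NaN, 15, 14.5, 16.33

# a = 1 with an illustrative prior of 10: every row, including Tues, gets a value
ordered_target_encoding(feature1, class_label, prior = 10, a = 1)
# 10, 12.5, 13, 14.75
```

The prior pulls early rows, which have little history, toward P. As more rows of the category accumulate, the encoding approaches the plain running average.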


5. Categorical Feature Combinations

  • ๋ฐ์ดํ„ฐ์…‹์˜ ํด๋ž˜์Šค๋ฅผ ๋ช…ํ™•ํ•˜๊ฒŒ ๊ตฌ๋ถ„ํ•  ์ˆ˜ ์žˆ๋Š” ์ค‘๋ณต๋˜๋Š” ๋ฒ”์ฃผํ˜• ๋ณ€์ˆ˜๊ฐ€ 2๊ฐœ ์ด์ƒ ์กด์žฌํ•  ๋•Œ, ์ด๋ฅผ ํ•˜๋‚˜์˜ ๋ณ€์ˆ˜๋กœ ํ†ตํ•ฉํ•ด ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์„ ์ž๋™์œผ๋กœ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

Example

  • If country alone determines hair_color, only one of the two variables is needed to predict class_label. CatBoost combines two such variables, which have the same information gain, into a single variable. This keeps the number of features from growing and reduces computation.
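A categorical combination is just a new feature whose values are pairs of the original values. This base-R sketch uses made-up country and hair_color data. It shows why one of the two columns is redundant when country determines hair_color: the combined feature has exactly as many levels as country alone.

It only illustrates what a combination is. CatBoost builds its combinations internally and greedily while growing trees.

```r
# Made-up data in which country fully determines hair_color.
country    <- c("KR", "KR", "SE", "SE", "NG", "SE")
hair_color <- c("black", "black", "blond", "blond", "black", "blond")

# Check: each country maps to exactly one hair colour.
colours_per_country <- tapply(hair_color, country, function(h) length(unique(h)))
all(colours_per_country == 1)            # TRUE

# A combination feature pairs the two values; drop = TRUE keeps only observed pairs.
combo <- interaction(country, hair_color, sep = "_", drop = TRUE)
nlevels(combo)                           # 3, the same as length(unique(country))
```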

6. One-hot Encoding

  • In fact, categorical variables are not always target-encoded.

  • For categorical variables with low cardinality, CatBoost applies one-hot encoding by default.

  • ์•„๋ฌด๋ž˜๋„ Low Cardinality ๋ฅผ ๊ฐ€์ง€๋Š” ๋ฒ”์ฃผํ˜• ๋ณ€์ˆ˜์˜ ๊ฒฝ์šฐ Target Encoding ๋ณด๋‹ค One-hot ์ด ๋” ํšจ์œจ์ ์ด๋ผ ๊ทธ๋Ÿฐ ๋“ฏ ํ•ฉ๋‹ˆ๋‹ค.

  • When you create a CatBoost model in Python, you can set the hyperparameter one_hot_max_size = N. CatBoost then automatically one-hot encodes every categorical variable with at most N distinct values (levels).
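The one_hot_max_size rule can be mimicked in base R to see which columns would be one-hot encoded and which would get target statistics. The column names are borrowed from the credit data used in the practice section.

This is only a sketch under stated assumptions:

- choose_encoding is a hypothetical helper;
- the threshold of 2 and the small data frame are illustrative;
- it reproduces the decision rule only, not CatBoost's encoders.

```r
# Hypothetical helper: decide per categorical column which encoding the
# one_hot_max_size rule would pick (levels <= threshold -> one-hot).
choose_encoding <- function(df, one_hot_max_size = 2) {
  cat_cols <- names(df)[vapply(df, is.factor, logical(1))]
  vapply(cat_cols, function(col) {
    if (nlevels(df[[col]]) <= one_hot_max_size) "one-hot" else "target statistics"
  }, character(1))
}

df <- data.frame(
  telephone       = factor(c("yes", "no", "yes")),      # 2 levels
  account.balance = factor(c("low", "mid", "high")),    # 3 levels
  amount          = c(1000, 2500, 400)                  # numeric, not categorical
)
choose_encoding(df, one_hot_max_size = 2)
# telephone: "one-hot", account.balance: "target statistics"

# What one-hot encoding of telephone looks like (one 0/1 column per level):
one_hot <- model.matrix(~ telephone - 1, data = df)
one_hot
```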


7. Handling Numerical Variables

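  • As in other tree-based algorithms, CatBoost scores candidate split thresholds (for example with information gain) and picks the best one. The candidates are not every observed value. Each numerical feature is first quantized into a limited number of borders, controlled by the border_count parameter.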

8. Optimized Parameter tuning

  • CatBoost's default parameters are already well optimized, so you rarely need to spend much effort on tuning. (XGBoost and LightGBM, by contrast, are very sensitive to tuning.)

  • Most boosting models are tuned mainly to diversify the trees and to counter overfitting. CatBoost handles both with its internal algorithms, so there is little need to tune it.

  • ๊ตณ์ด ํ•œ๋‹ค๋ฉด learning_rate, random_strength, L2_regulariser๊ณผ ๊ฐ™์€ ํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹์ธ๋ฐ, ๊ฒฐ๊ณผ๋Š” ํฐ ์ฐจ์ด๊ฐ€ ์—†๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.


[Pros and Cons of CatBoost]

1. Pros

  • It overfits less than other GBMs.

  • Its dedicated encoding of categorical variables gives the model high accuracy and speed.

  • Categorical variables can be fed to the model directly, without one-hot encoding, label encoding or similar preprocessing.


2. Cons

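  • Missing data is only partly handled. Missing values in numerical features are supported (see the nan_mode parameter), but missing values in categorical features are usually replaced with a placeholder category before training.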

  • It is not well suited to sparse matrices, i.e., datasets in which most entries are zero or missing.

    • For example, the user-item matrices common in recommender systems are usually sparse. To use such data, first transform it so that it is no longer sparse (for example, with embeddings), then feed it to CatBoost.
  • When most of the features are numerical, training is slower than with LightGBM. (So CatBoost is best used when most features are categorical.)


[Practice]

> # [Setup]
> #install.packages('devtools')
> library(devtools)
> devtools::install_url('https://github.com/catboost/catboost/releases/download/v0.20/catboost-R-Windows-0.20.tgz', INSTALL_opts = c("--no-multiarch", "--no-test-load"))
> library(catboost)
> library(caret)


> credit <- read.csv('D:\\ADP\\credit_final.csv')
> class(credit$credit.rating) 
> credit$credit.rating <- as.factor(credit$credit.rating)
> levels(credit$credit.rating) <- c("pos", "neg")
> class(credit$credit.rating)


> credit$foreign.worker <- as.factor(credit$foreign.worker)
> credit$telephone <- as.factor(credit$telephone)
> credit$other.credits <- as.factor(credit$other.credits)
> credit$bank.credits <- as.factor(credit$bank.credits)
> credit$account.balance <- as.factor(credit$account.balance)

> # [Data split]
> idx <- createDataPartition(y = credit$credit.rating,
+                            times = 1, p = 0.7, list = FALSE)
> train <- credit[idx,]
> test <- credit[-idx,]
> table(train$credit.rating)

pos neg 
210 490 


> # [Upsample the training set]
> # upSample's yname argument names the outcome column directly, so dplyr's rename() is not needed
> train <- upSample(x = subset(train, select = -credit.rating),
+                   y = train$credit.rating, yname = 'credit.rating')
> table(train$credit.rating)

pos neg 
490 490 

> # [Create the k-folds and set up a random grid search]
> cv_folds <- createMultiFolds(train$credit.rating, k = 3, times = 3)
> fit_ctrl <- trainControl(method = "repeatedcv", 
+                          number = 3,
+                          repeats = 3,
+                          index = cv_folds,
+                          search = 'random',
+                          verboseIter = TRUE)

> # [Train the model]
> train_x <- subset(train, select = c(-credit.rating))
> train_y <- train$credit.rating
> catboost_model <- train(x = train_x, y = train_y,
+                      method = catboost.caret,
+                      metric = 'Accuracy',
+                      preProcess = c("zv", "center", "scale", "spatialSign"),
+                      #tuneGrid = grid,
+                      trControl = fit_ctrl,
+                      maximize = TRUE,
+                      tuneLength = 10)
...
92:	learn: 0.0003169	total: 430ms	remaining: 32.4ms
93:	learn: 0.0003046	total: 439ms	remaining: 28ms
94:	learn: 0.0002871	total: 449ms	remaining: 23.6ms
95:	learn: 0.0002675	total: 455ms	remaining: 19ms
96:	learn: 0.0002578	total: 461ms	remaining: 14.3ms
97:	learn: 0.0002528	total: 466ms	remaining: 9.51ms
98:	learn: 0.0002379	total: 470ms	remaining: 4.75ms
99:	learn: 0.0002233	total: 476ms	remaining: 0us

> # [Inspect the best model]
> catboost_model
Catboost 

980 samples
 20 predictor
  2 classes: 'pos', 'neg' 

Pre-processing: centered (15), scaled (15), spatial sign
 transformation (15), ignore (5) 
Resampling: Cross-Validated (3 fold, repeated 3 times) 
Summary of sample sizes: 654, 652, 654, 654, 654, 652, ... 
Resampling results across tuning parameters:

  depth  learning_rate  l2_leaf_reg  rsm  border_count  Accuracy 
  2      0.1158270      1e-06        0.8  168           0.7721021
  2      0.3318660      1e-06        0.8   88           0.7881399
  2      0.7353266      1e-06        0.8  252           0.8054600
  4      0.0971376      1e-06        0.9   79           0.8173620
  5      0.4965535      1e-06        0.7  214           0.8064741
  5      0.6898669      1e-06        0.7  186           0.7969703
  8      0.1959387      1e-01        1.0   39           0.8418541
  8      0.4183411      1e-01        0.8  151           0.8238254
  9      0.5728624      1e-06        0.8   59           0.8139413
  9      0.7598337      1e-06        0.7  143           0.8115949
  Kappa    
  0.5442042
  0.5762798
  0.6109199
  0.6347241
  0.6129483
  0.5939407
  0.6837082
  0.6476508
  0.6278825
  0.6231899

Tuning parameter 'iterations' was held constant at a value of 100
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were depth = 8, learning_rate
 = 0.1959387, iterations = 100, l2_leaf_reg = 0.1, rsm = 1
 and border_count = 39.

> catboost_model$bestTune
  depth learning_rate iterations l2_leaf_reg rsm border_count
7     8     0.1959387        100         0.1   1           39
> ## border_count: number of borders (split candidates) used to quantize numerical features
> ## rsm: fraction of features considered at each split selection (features are re-sampled at random)


> # [Check variable importance]
> plot(varImp(catboost_model))


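> # [Confusion matrix for the test-set predictions]
> # Note: caret's confusionMatrix(data, reference) expects the predictions first and the true labels second.
> # In the call below the two are swapped, so the rows labeled 'Prediction' are really the true labels.
> # The class-specific statistics are therefore swapped too:
> #   - Sensitivity <-> Pos Pred Value
> #   - Specificity <-> Neg Pred Value
> #   - Prevalence  <-> Detection Prevalence
> # The intended call is confusionMatrix(predict(catboost_model, test, type = 'raw'), test$credit.rating).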
> confusionMatrix(test$credit.rating, predict(catboost_model, test, type = 'raw'))
Confusion Matrix and Statistics

          Reference
Prediction pos neg
       pos  44  46
       neg  24 186
                                          
               Accuracy : 0.7667          
                 95% CI : (0.7146, 0.8134)
    No Information Rate : 0.7733          
    P-Value [Acc > NIR] : 0.63898         
                                          
                  Kappa : 0.4027          
                                          
 Mcnemar's Test P-Value : 0.01207         
                                          
            Sensitivity : 0.6471          
            Specificity : 0.8017          
         Pos Pred Value : 0.4889          
         Neg Pred Value : 0.8857          
             Prevalence : 0.2267          
         Detection Rate : 0.1467          
   Detection Prevalence : 0.3000          
      Balanced Accuracy : 0.7244          
                                          
       'Positive' Class : pos 
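Because the two arguments were swapped in the call above, it helps to recompute the main metrics by hand from the printed counts.

This sketch treats the rows of the printed matrix as the true test labels. Their totals, 90 pos and 210 neg, match a 30% test split of the class counts shown earlier. So recall for pos is 44/90, while the printed Sensitivity of 0.6471 is actually the precision 44/68.

```r
# Counts copied from the printed confusion matrix above.
# Rows are the TRUE test labels (the first argument passed to confusionMatrix),
# columns are the model's predictions.
cm <- matrix(c(44, 24, 46, 186), nrow = 2,
             dimnames = list(truth = c("pos", "neg"),
                             predicted = c("pos", "neg")))

accuracy    <- sum(diag(cm)) / sum(cm)               # (44 + 186) / 300
recall      <- cm["pos", "pos"] / sum(cm["pos", ])   # 44 / 90: share of true pos found
precision   <- cm["pos", "pos"] / sum(cm[, "pos"])   # 44 / 68: predicted pos that are right
specificity <- cm["neg", "neg"] / sum(cm["neg", ])   # 186 / 210
round(c(accuracy = accuracy, recall = recall,
        precision = precision, specificity = specificity), 4)
# accuracy 0.7667, recall 0.4889, precision 0.6471, specificity 0.8857
```

Read this way, the model finds fewer than half of the pos cases in the test set, which the swapped output hides.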