Ada Boost 심층 이해

박재한·2022년 1월 3일
0

Machine Learning

목록 보기
1/6

참고 사이트

1. Ada Boost 개요

Ada Boost(Adaptive boost)는 Boost 기법 중 가장 기본이 되는 것으로 다른 발전된 Boost 기법들의 이해의 기반이 된다.
Ada Boost에는 stump라는 결정(decision)의 단위가 되는 트리(depth가 1인 decision tree로 보면된다.)가 있는데 이 stump는 depth가 1이라 오로지 하나의 feature만 가지고 판단한다. 그리고 이러한 stump 여러 개가 내린 결정들을 모아서 투표(voting)를 하여 다수결의 결정을 따르는 방식이 Ada Boost이다.
아래와 같이 노드 하나에 두개의 리프(leaf)를 지닌 트리를 stump라고 한다.

AdaBoost는 아래와 같이 여러 개의 stump로 구성이 된다. 이를 Forest of stumps라고 한다.

트리와 다르게 stump는 정확한 분류를 하지 못한다. 여러 질문을 통해 데이터를 분류하는 트리와 다르게, stump는 단 하나의 질문으로 데이터를 분류해야하기 때문이다.(depth가 1인 decision tree라서 질문이 하나만 있을 수 밖에 없다.) 따라서 stump는 약한 학습기(weak learner)이다.
예를 들어 '심장병 발생한다' 여부를 판단하는 모델을 만들 때, 관련 데이터의 feature가 '가슴 통증', '혈관 막힘', '환자의 체중' 이렇게 3가지 feature가 있을 때 '가슴 통증'만으로 판단하는 tree, '혈관 막힘'만으로 판단하는 tree, '환자의 체중'만으로 판단하는 tree로 stump가 구성이 되어 질 수 있다. 특히 '환자의 체중' feature는 범주형이 아닌 연속형 변수이기 때문에 tree를 구성하는 질문이 여러개 있을 수 있어(예를 들어 체중이 80Kg이상이냐, 70kg이상이냐 등등) 하나의 feature에 다수개의 tree가 나올 수 있다. 여기서 혈관이 막히고 심장병이 있고, 체중이 90kg인 데이터가 주어진다면 이 데이터로 적어도 앞에서의 3개이상의 tree를 거쳐가면서 판단하게 되는데 각각의 tree(stump, weak learner)마다 '심장병이 있다'라고 판단하는 것도 있고 '심장병이 없다' 라고 판단하는 것도 있을 것이다. 이렇게 판단에 참여한 다수개의 tree의 판단의 결과를 취합하여 가장 많은 판단을 내린 결과를 투표(voting)을 통해서 선정하여 최종적으로 심장병 있다 없다의 여부를 결정하는 것이다. 이것이 바로 Ada Boost의 기본적인 개념이다.
실제로는 각각의 tree의 예측 결과에 가중치(weight)를 부여하여(가중치 부여 방법은 이 다음에 바로 다룬다.) 모든 tree의 예측 결과를 대등하게 보지 않고 가중치가 높은 결정에 좀 더 비중을 두는 식으로 결정한다.
AdaBoost에서는 특정 stump가 다른 stump보다 더 중요한 즉 가중치가 더 높다고 했는데, 아래의 그림에서 보는 것처럼 크기가 큰 것은 가중치가 더 높은 stump를 뜻하고 여기서 가중치가 높다는 것을 Amount of Say가 높다고 표현합니다. 결과에 미치는 영향이 크다는 것을 뜻한다.(아래 다시 설명하겠지만 가중치는 Amount of Say값으로 결정이 되는데 Amount of Say값은 오류가 적을 수록 양의 값으로 급격히 크지고 오류가 많을 수록 음의 값으로 급격히 크진다. 즉 오류가 적을 수록 결정에 대한 권위(Authority)가 더 높아지는 셈이다.)

또한 첫번째 stump의 판단에서 발생한 오류(error)는 두 번째 stump의 판단에 영향을 준다.(두 번째 stump의 Amount of Say 결정에 첫번째 stump의 가중치를 이용한다. 자세한 것은 다음 동작 과정 설명을 봐야 하지 한마디로 설명은 아주 곤란하다.) 두 번째 stump에서 발생한 오류(error)역시 세번째 stump 결과에 영향을 준다. 그렇게 마지막 stump까지 줄줄이 영향을 준다.
참고로 stump의 갯수는 hyper parameter로 사용자가 사전에 정해줘야 한다.(사이킷런의 AdaBoostClassifier는 default로 50개이다.)

2. Ada Boost 작동 원리

다음 심장병 발생에 관한 예제를 가지고 작동 원리를 설명하고자 한다. (참고로 작동원리는 'Ada Boost 바닥부터 구현해 보기'를 참조해서 관련 소스를 하나하나 직접 구현해 가면서 테스트 해보고 충분히 이해한 다음에 작성하는 것이다.)

Chest Pain, Blocked Arteries, Patient Weight에 따른 Heart Disease 여부에 대한 데이터이다. 맨 처음 Sample Weight는 처음에는 8개의 데이터 모두 동일하게 1/(total number of samples) = 1/8이다. 모든 sample의 가중치가 1/8로 동일하다. 이제 각각의 feature가 target value(여기서는 Heart Disease)에 미치는 영향에 대해 살펴보자.

Chest Pain과 Heart Disease와의 관계이다.

단순하게 Chest Pain이 Yes이면 Heart Disease도 Yes라고 판단하는 모델이다.(Chest Pain이 No이면 Heart Disease가 Yes라고 판단하는 stump도 있을 수 있는 데 이것은 오류가 상대적으로 훨씬 더 심해서 배제하였다.) 총 8개의 데이터 중 Chest Pain이 Yes인 데이터(즉, Heart Disease를 Yes라고 판단한 데이터)는 5개, No인 데이터는 3개이다. Heart Disease를 Yes라고 판단한 것 중 올바르게 판단한 것은 3개, 틀리게 판단한 것은 2개이다. 반대로, Heart Disease를 No라고 판단한 것 중 올바르게 판단한 것은 2개, 틀리게 판단한 것은 1개이다. 따라서 위와 같이 구분이 되었다.

이젠, Blocked Arteries와 Heart Disease와의 관계이다.

다음은 마지막으로 Patient Weight와 Heart Disease와의 관계이다.

각 Stump의 지니 계수를 구한다.
지니 계수를 구하는 방법은 지니 계수(Gini Index)란?을 참조하였다.
이것도 참조하면 도움이 된다.

  • Chest Pain
    1(35)2(25)2=0.481-(\frac{3}{5})^2-(\frac{2}{5})^2=0.48
    1(23)2(13)2=0.44444444441-(\frac{2}{3})^2-(\frac{1}{3})^2=0.4444444444
    (58)0.48+(38)0.444=0.466665(\frac{5}{8})*0.48+(\frac{3}{8})*0.444=0.466665

  • Blocked Asteries
    1(36)2(36)2=0.51-(\frac{3}{6})^2-(\frac{3}{6})^2=0.5
    1(12)2(12)2=0.51-(\frac{1}{2})^2-(\frac{1}{2})^2=0.5
    (68)0.5+(28)0.5=0.5(\frac{6}{8})*0.5+(\frac{2}{8})*0.5=0.5

  • Patient Weight
    1(33)2(03)2=01-(\frac{3}{3})^2-(\frac{0}{3})^2=0
    1(45)2(15)2=0.321-(\frac{4}{5})^2-(\frac{1}{5})^2=0.32
    (38)0+(58)0.32=0.2(\frac{3}{8})*0+(\frac{5}{8})*0.32=0.2

마지막 Stump(Patient Weight)의 지니 계수가 가장 작기 때문에 forest의 첫 Stump로 지정한다. 이 Stump가 최종 결과 예측에 있어 얼마만큼의 중요도가 있는지(얼마만큼의 가중치가 있는지) 보겠다.

2.1 Amount of Say 구하기

Heart Disease stump에서 틀리게 분류한 것이 Yes인데 No로 분류한 1개밖에 없다. 따라서 Total Error = 1/8이다.(176보다 크면 Yes이고 작으면 No라서 167라 No라고 판단했는데 실제로는 Yes이다.)

모든 Sample Weights의 합은 1이기 때문에, Total Error는 0과 1 사이의 값을 갖는다. 이 Total Error가 Amount of Say를 결정한다. Amount of Say는 최종 분류에 있어서 해당 Stump가 얼마만큼의 영향을 주는가를 뜻한다.(해당 stump의 결정에 있어서의 무게감, 권위, 가중치이다.) Amount of Say를 구하는 공식은 아래와 같다.
AmountofSay=12ln(1TotalErrorTotalError)Amount\,of\,Say=\frac{1}{2}ln(\frac{1-Total\,Error}{Total\,Error})
Amount of Say를 그래프로 그려보면 아래와 같다. X축은 Total Error, Y축은 Amount of Say이다. Total Error가 0이면 Amount of Say는 굉장히 큰 양수이고, Total Error가 1이면 Amount of Say는 굉장히 작은 음수가 된다. 따라서 Total Error가 0이면 항상 올바른 분류를 한다는 뜻이고, 1이면 항상 반대로 분류를 한다는 뜻이다. Total Error가 0.5일 때는 Amount of Say가 0이다. 동전을 던지는 것과 마찬가지로 의미가 없다는 뜻이다.

다시 stump로 돌아와서 Total Error가 1/8이라고 했으므로,
AmountofSay=12ln(11/81/8)=0.97Amount\,of\,Say=\frac{1}{2}ln(\frac{1-1/8}{1/8})=0.97
Amount of Say는 0.97이다.
그래프 상에서 표현해보면, Total Error가 1/8이고, Amount of Say = 0.97인 아래 지점이다.

2.2 Sample 가중치(weight) 설정

Adaboost에서는 하나의 Stump가 잘못 분류한 sample에 대해서는 다음 Stump로 넘겨줄 때 가중치를 더 높여서 넘겨준다. 그래야 다음 Stump에서 해당 Sample에 더 집중해서 올바로 분류해주기 때문이다. 맨 처음 Weight Stump에서는 아래 빨간 네모 안에 있는 sample만 잘못 분류했다. 따라서 해당 sample의 weight를 1/8보다 크게 하고, 나머지 sample의 weight는 1/8보다 작게 해서 다음 Stump로 넘겨준다. 다음 Stump로 넘겨줄 때의 새로운 sample weight를 구하는 공식은 아래와 같다.

  • Stump에서 잘못 분류한 sample일 경우(weight 높인다)

    NewSampleWeight=(1/8)e0.97=(1/8)2.64=0.33New\,Sample\,Weight = (1/8) * e^{0.97} = (1/8) * 2.64 = 0.33 이다. 기존의 sample weight = 1/8 = 0.125였는데 이보다 더 높아졌다.
  • Stump에서 잘 분류한 sample일 경우(weight 낮춘다)

    amount of say에 - 부호만 붙이면 된다.
    NewSampleWeight=(1/8)e0.97=(1/8)0.38=0.05New\,Sample\,Weight = (1/8) * e^{-0.97} = (1/8) * 0.38 = 0.05 입니다. 기존의 weight인 0.125보다 더 작아졌다.

다시 설명하자면, 이전 Stump에서 잘못 분류된 sample의 경우 sample weight를 증가시켜주고, 이전 Stump에서 잘 분류된 sample의 경우 sample weight를 감소시켜준다. 그래야 다음 Stump에서 이전 Stump에서 잘못 분류한 것에 더 집중을 해서 올바른 분류를 해주기 때문이다. 새로 구한 sample weight는 아래와 같다.

New Sample Weight는 공식에 의해 구한 weight이다. 단, New Sample Weight를 다 더했을 때 값은 0.68로 1이 되지 않는다. Sample Weight의 합은 항상 1이 되어야하므로 오른쪽 Weight처럼 정규화(Normalize)시켜준다. 테이블의 맨 오른쪽 Norm. Weight이다.
이제 기존의 Weight는 모두 지우고 정규화된 새로운 Weight만 보겠다.

그 다음 스텝은 샘플링을 통해 새로운 테이블을 만들어 주는 것이다. 0부터 1까지의 숫자를 랜덤하게 뽑는다. 이때 0~0.07 사이의 숫자가 나오면 첫번째 sample을 선택한다. 0.07~0.14 사이의 숫자가 나오면 두번째 sample을 선택한다. 0.14~0.21 사이의 숫자가 나오면 세번째 sample을 선택한다. 0.21~0.70 사이의 숫자가 나오면 네번째 sample을 선택한다. 눈치 채셨겠지만 sample weight의 누적 숫자에 해당하는 sample을 뽑는 것이다. 그렇게 원래 테이블의 sample수와 똑같은 sample 수를 가진 새로운 테이블을 구성한다. 뽑힌 sample들을 보니 중복되는 것도 있다. 원래 테이블에서 sample weight가 0.49인 sample이 4번이나 뽑혔다. 당연히 0.21 ~ 0.70 사이의 숫자가 나오면 해당 sample을 뽑기 때문에 확률이 많을 것이다.

이제 원래의 테이블은 지우고, 샘플링한 새로운 테이블을 가져온다. 모든 sample의 weight는 다시 1/8로 통일시켜준다. 첫 Stump에서 잘못 분류했던 sample이 4번이나 포함된다고 했으니 sample weight는 1/8로 동일하더라도 똑같은 데이터가 4개가 있어서 실제로는 4/8의 weight를 갖는 것이다. 이는 처음에 잘못 분류를 했기 때문에, 그 다음에는 weight를 높여서 제대로 분류하기 위함이다. weight가 높아지니 해당 sample에 가중치를 더 두고 분류를 할 것이다.

그런데 실제 구현해 보니 weighted된 new sample weight 적용 후 sample weight를 1/8로 다시 초기화 시키지 않고 그대로 다음 stump로 넘기는 것이 결과적으로 분류 성능이 더 좋았다.

2.3 두번째 round

첫번째 round에서 stump를 선정하고, Total Error를 구하고, Amount of Say를 구하였다. 첫번째 round의 Amount of Say를 가지고 sample weight를 업데이트 한다. 업데이트 된 sample weight를 기준으로 학습 데이터에서 sample을 다시 뽑아서 gini 계수가 가장 낮은(분류가 상대적으로 가장 잘된) stump를 선정한다.
두 번째 round부터는 참고 자료가 아닌 Ada Boost 직접 구현한 소스로부터 결과를 받아서 그것을 바탕으로 정리할 것이다.

sample weight
첫번째 round의 Amount of Say를 이용해서 구한 sample weight는 다음과 같으며 두 번째 round에서 학습할 sample data 선택 및 stump(classifier, weak learner) 선택에 영향을 미친다.

구분#1#2#3#4#5#6#7#8
S.W.0.047245720.047245720.047245720.330717780.047245720.047245720.047245720.04724572
S.W. Norm.0.071428820.071428820.071428820.499998290.071428820.071428820.071428820.07142882

stump
업데이트 된 sample weight가 반영된 sample로 다시 뽑아서 결정된 stump는 다음과 같다.

Decision Tree를 잘 보면 'Patient Weight' > 161.5 이면 심장병 발생이 Yes로 예측하고 161.5 이하이면 심장병 발생이 No로 예측한다.('class =' 부분을 보면 된다)
각 노드에서 value는 왼쪽이 -1로 심장병 발생 No에 해당하고 오른쪽이 1로 심장병 발생 Yes에 해당한다. root node에서는 오른쪽이 실제로 Yes인 sample의 가중치의 합(0.07+0.07+0.07+0.4999)인 0.714이고 왼쪽은 실제로 No인 sample의 가중치의 합(0.07*4)인 0.268이다.
오른쪽 child node에서는 Yes라고 예측한 샘플이 6개이고(samples=6), 이 중 실제 Yes로 맞게 분류한 예측의 가중치의 합은 0.714(4개), 실제 No인데 Yes로 잘못 예측한 것의 가중치의 합은 0.143(2개)이다.
다음 왼쪽 child node에서는 No라고 예측한 샘플이 2개이고(samples=2), 이 중 실제 No로 맞게 분류한 예측의 가중치의 합은 0.143(2개), 실제 Yes인데 No로 잘못 예측한 것의 가중치의 합은 0.0(0개)이다.

예측값

y_predy_pred_AOS
10.89587482
10.89587482
10.89587482
10.89587482
-1-0.89587482
-1-0.89587482
10.89587482
10.89587482

Amount of Error
8개중 2개가 오류이므로 Total Error는 예측이 오류인 샘플의 가중치의 합인 0.143이다. 이 때 Amount of Say값은 0.89587482이다.
이 Amount of Say를 예측값에 반영한다.(예측값에 곱한다) 즉 같은 예측이라도 stump의 Amount of Say(발언의 권위, 신뢰도 정도로 이해하자)에 따라 예측의 비중이 달라진다.

이제 다음 round에서 현재 round의 Amount of Say값을 가지고 sample weight를 업데이트 할 것이다. sample weight 업데이트로 다음 round에서 새로운 stump가 선정이 될 것이다.

2.4 세번째 round

sample weight

구분#1#2#3#4#5#6#7#8
S.W.0.029160840.029160840.029160840.204124450.029160840.029160840.174963290.17496329
S.W. Norm.0.041666950.041666950.041666950.291666680.041666950.041666950.249999270.24999927

stump

Yes로 예측한 3개중 오류 비율은 전체의 0.083이고 No로 예측한 5개중 오류 비율은 전체의 0.125이다.

예측값

y_predy_pred_AOS
-1-0.667494396576355
-1-0.667494396576355
-1-0.667494396576355
10.667494396576355
10.667494396576355
10.667494396576355
-1-0.667494396576355
-1-0.667494396576355

Amount of Error
Total Error는 0.083 + 0.125 = 0.208, Amount of Say는 0.6674943965763551이다. 이 값을 예측값에도 반영한다.

2.5 네번째 round

sample weight

구분#1#2#3#4#5#6#7#8
S.W.0.081223350.081223350.081223350.149622770.081223350.081223350.128247710.12824771
S.W. Norm.0.099999820.099999820.099999820.18421120.099999820.099999820.157894840.15789484

stump

Yes로 예측한 3개중 오류 비율은 전체의 0.0이고 No로 예측한 5개중 오류 비율은 전체의 0.184이다.

예측값

y_predy_pred_AOS
10.744034190435884
10.744034190435884
10.744034190435884
-1-0.744034190435884
-1-0.744034190435884
-1-0.744034190435884
-1-0.744034190435884
-1-0.744034190435884

Amount of Error
Total Error는 0.0 + 0.184 = 0.184, Amount of Say는 0.744034190435884이다. 이 값을 예측값에도 반영한다.

2.6 다섯번째 round

sample weight

구분#1#2#3#4#5#6#7#8
S.W.0.047519220.047519220.047519220.387655520.047519220.047519220.075030530.07503053
S.W. Norm.0.061290390.061290390.061290390.499998950.061290390.061290390.096774540.09677454

stump

Yes로 예측한 6개중 오류 비율은 전체의 0.194이고 No로 예측한 2개중 오류 비율은 전체의 0.0이다.

예측값

y_predy_pred_AOS
10.7135539843362075
10.7135539843362075
10.7135539843362075
10.7135539843362075
-1-0.7135539843362075
-1-0.7135539843362075
10.7135539843362075
10.7135539843362075

Amount of Error
Total Error는 0.194 + 0.0 = 0.194, Amount of Say는 0.7135539843362075이다. 이 값을 예측값에도 반영한다.

2.7 여섯번째 round

sample weight

구분#1#2#3#4#5#6#7#8
S.W.0.030026160.030026160.030026160.244949490.030026160.030026160.197539380.19753938
S.W. Norm.0.038000150.038000150.038000150.310000220.038000150.038000150.249999510.24999951

stump

Yes로 예측한 3개중 오류 비율은 전체의 0.076이고 No로 예측한 5개중 오류 비율은 전체의 0.114이다.

예측값

y_predy_pred_AOS
-1-0.7250006146288984
-1-0.7250006146288984
-1-0.7250006146288984
10.7250006146288984
10.7250006146288984
10.7250006146288984
-1-0.7250006146288984
-1-0.7250006146288984

Amount of Error
Total Error는 0.076 + 0.114 = 0.19, Amount of Say는 0.7250006146288984이다. 이 값을 예측값에도 반영한다.

stump의 갯수는 hyper parameter로서 사용자가 정한다고 하였는데 본 예제에서는 6개로 정하고 테스트하였다. 따라서 6개의 stump를 생성시켜 모든 round를 다 했으므로 더 이상의 round는 없다!

2.8 결과 예측

6개의 round를 통해서 모두 6개의 예측 결과가 나왔는데 각각이 가중치가 다 다르다. 이를 정리해 보면,

6개 stump의 예측 결과

clf1clf2clf3clf4clf5clf6TARGET
111-111-11
211-111-11
311-111-11
4-111-1111
5-1-11-1-11-1
6-1-1-1-1-11-1
7-11-1-11-1-1
8-11-1-11-1-1

Amount of Say를 반영한 예측 결과

clf1clf2clf3clf4clf5clf6SUM
10.9729520.895875-0.6674940.7440340.713554-0.7250011.933920
20.9729520.895875-0.6674940.7440340.713554-0.7250011.933920
30.9729520.895875-0.6674940.7440340.713554-0.7250011.933920
4-0.9729520.8958750.667494-0.7440340.7135540.7250011.284938
5-0.972952-0.8958750.667494-0.744034-0.7135540.725001-1.933920
6-0.972952-0.8958750.667494-0.744034-0.7135540.725001-1.933920
7-0.9729520.895875-0.667494-0.7440340.713554-0.725001-1.500052
8-0.9729520.895875-0.667494-0.7440340.713554-0.725001-1.500052

6번의 round를 거쳐 6개의 stump가 예측한 결과와 이 결과에 Amount of Say라는 가중치를 곱한 결과를 표로 정리하였다. 1번 sample의 경우 stump(classifier)1에서는 0.972이고 2번 stump에서는 0.895, 마지막 6번 stump에서는 -0.725이다. +값은 YES로 예측한 것이고 -값은 NO로 예측한 것이다. 이들 값을 모두 더하면 +값이 더 많거나 가중치가 더 높으면 합계가 양수로 나올 것이고 -값이 더 많거나 가중치가 더 높으면 합계가 음수로 나올 것이다. 1번 sample의 경우는 합계가 1.933으로 양수가 나와서 YES로 분류하게 된다. 나머지 sample도 마찬가지 이다. 즉 각각의 sample의 예측값의 합계가 양수이면 1로 음수이면 -1로 만들면 실제 y의 값과 바로 비교가 가능한데 본 예제에서는 1, 1, 1, 1, -1, -1, -1, -1로 실제 y값과 동일해서 예측이 정확하게 맞았다고 볼 수 있다.

profile
바쁘게 부지런하게 논리적으로 살자!!!

0개의 댓글