Random Forest

김혜성·2021년 2월 18일
0

Machine Learning

목록 보기
4/10
post-custom-banner

Random Forest의 정의

나무가 모이면 뭐가 될까요?
숲이 되죠?
Decision Tree(결정 트리)가 모이면 뭐가 될까요?
Random Forest(랜덤 포레스트)가 됩니다.
허허
결정 트리 하나로도 학습시킬 수 있지만 이전 포스트에서 언급했듯이 오버피팅이 발생할 수 있다. 여러 결정 트리를 통해 랜덤 포레스트를 만들면 오버피팅을 줄일 수 있다.
랜덤 포레스트의 개별적인 분류 알고리즘은 결정 트리이지만 각 분류기가 학습하는 데이터 세트는 전체 데이터에서 일부가 중첩되도록 샘플링된 데이터 세트이다. 이처럼 중첩되도록 분류하기에 앙상블 기법 중 Bootstrap(부트스트랩)이 사용된 Bagging(배깅)에 속한다.

랜덤 포레스트 과정

random forest
위의 그림은 랜덤 포레스트를 시각화한 것이다.
트리0부터 트리4까지의 결정 트리가 만든 Boundary들의 평균을 내어 랜덤 포레스트의 경계가 결정된다.

많은 Feature를 담은 하나의 깊은 트리를 만들면 오버피팅이 심하기에 무작위로 몇가지의 Feature를 뽑아서 얕은 트리를 여러 개 만들어 나온 여러 예측값들 중 가장 빈도가 높은 값(분류) 또는 평균값(회귀)을 최종 예측값으로 정한다.

실습

class sklearn.ensemble.RandomForestClassifier(n_estimators=100, *, criterion='gini',
max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, bootstrap=True, oob_score=False, n_jobs=None,
random_state=None, verbose=0, warm_start=False, class_weight=None,
ccp_alpha=0.0, max_samples=None)

Parameter

  • n_estimators: 랜덤 포레스트 안의 결정 트리의 개수. 개수가 많을수록 Decision Boundary가 깔끔하게 나오지만 메모리와 학습시간이 증가한다.
  • max_features: 무작위로 선택할 Feature의 개수. 일반적으로 default값을 사용한다. bootstrap = True일 경우 복원추출로 택한다. max_features가 클수록 각 트리들의 예측값이 비슷해지고, 작을수록 달라져서 오버피팅이 줄어든다.

지금까지 다뤄온 분류기들은 상당히 많은 hyper parameter를 가지고 있기에 하나하나 맞춰보며 시도하기엔 한계가 있다. 이럴 때 GridSearchCV를 통해 최적의 hyper parameter로 튜닝할 수도 있다.

from sklearn.model_selection import GridSearchCV

#feature 중요도도 그려볼 수 있다.
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

featureimportances_values = final_bagging_rf_ensemble.feature_importances
feature_importances = pd.Series(feature_importances_values, index=cancer_dataset_df.columns)
feature_importances_top20 = feature_importances.sort_values(ascending=False)[:20] #중요한 순서로 정렬 후 top 20개만 뽑아 시각화

plt.figure(figsize=(8, 6))
plt.title('Feature Importances Top 20')
sns.barplot(x=feature_importances_top20, y=feature_importances_top20.index)
plt.show()

위의 코드는 GridSearchCV를 통해서 가장 중요한 feature 20개를 뽑아서 시각화하는 코드이다.

출처

profile
똘멩이
post-custom-banner

0개의 댓글