GridSearchCV

jh_k·2023년 2월 14일
0

데이터 분석

목록 보기
15/17

sklearn.model_selection.GridSearchCV
class sklearn.model_selection.GridSearchCV(estimator, param_grid, , scoring=None, n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2n_jobs', error_score=nan, return_train_score=False)

  • estimato : 사용할 알고리즘
  • param_grid : dict형태로 검정하고 싶은 하이퍼 파라미터값
  • scoring : 모형평가 기준 ex) "accuracy"
  • cv : 교차검증할 폴드 , 횟수
  • refit : 디폴트 True -> 가장 좋은 파라미터 설정으로 재학습 시킴.
  • return_train_score : 디폴트 False ->True 이면 train에 대한 score를 알려줌

kfold 교차검증과 동시에 최적의 하이퍼 파라미터를 찾아 가장 최적의 모델을 찾기위한 라이브러리로
refit을 True로 하면 바로 모델에 적합한 bestestimator을 통하여 적합해 볼 수 있다.


from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV,train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
iris = load_iris()
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.2,random_state=121)

dt_clf = DecisionTreeClassifier(random_state=156)

## parameter 들을 dictionary 형태로 설정
param = {
    "max_depth":[1,2,3],
    "min_samples_split":[2,3]
}

grid_dt = GridSearchCV(dt_clf,param_grid=param,cv=3,refit=True)
grid_dt.fit(x_train,y_train)

# GridSearch 결과 추출 하여 DataFrame으로 변환 attribute들을 컬럼으로 하여 나타냄
scores_df = pd.DataFrame(grid_dt.cv_results_)
scores_df.iloc[:,4:]
param_max_depth	param_min_samples_split	params	split0_test_score	split1_test_score	split2_test_score	mean_test_score	std_test_score	rank_test_score
        0	1	2	{'max_depth': 1, 'min_samples_split': 2}	0.700	0.7	0.70	0.700000	1.110223e-16	5
        1	1	3	{'max_depth': 1, 'min_samples_split': 3}	0.700	0.7	0.70	0.700000	1.110223e-16	5
        2	2	2	{'max_depth': 2, 'min_samples_split': 2}	0.925	1.0	0.95	0.958333	3.118048e-02	3
        3	2	3	{'max_depth': 2, 'min_samples_split': 3}	0.925	1.0	0.95	0.958333	3.118048e-02	3
        4	3	2	{'max_depth': 3, 'min_samples_split': 2}	0.975	1.0	0.95	0.975000	2.041241e-02	1
        5	3	3	{'max_depth': 3, 'min_samples_split': 3}	0.975	1.0	0.95	0.975000	2.041241e-02	1
        
print('GridSearchCV 최적 파라미터:',grid_dt.best_params_)
print(f'GridSearchCV 최고 정확도: {grid_dt.best_score_}')

# GridSearchCV의 refit으로 이미 학습이 된 estimator 반환
# 최적의 하이퍼파라미터로 DecisionTree에 적용된 모형
estimator = grid_dt.best_estimator_

## GridSearchCV의 best_estimator_로 학습된 모형의 예측
pred = estimator.predict(x_test)
accuracy = accuracy_score(y_test,pred)
print(f'테스트 데이터 세트 정확도 {accuracy}')

GridSearchCV 최적 파라미터: {'max_depth': 3, 'min_samples_split': 2}
GridSearchCV 최고 정확도: 0.975
테스트 데이터 세트 정확도 0.9666666666666667

profile
Just Enjoy Yourself

0개의 댓글