sklearn.model_selection.GridSearchCV
class sklearn.model_selection.GridSearchCV(estimator, param_grid, , scoring=None, n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2n_jobs', error_score=nan, return_train_score=False)
kfold 교차검증과 동시에 최적의 하이퍼 파라미터를 찾아 가장 최적의 모델을 찾기위한 라이브러리로
refit을 True로 하면 바로 모델에 적합한 bestestimator을 통하여 적합해 볼 수 있다.
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV,train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
iris = load_iris()
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.2,random_state=121)
dt_clf = DecisionTreeClassifier(random_state=156)
## parameter 들을 dictionary 형태로 설정
param = {
"max_depth":[1,2,3],
"min_samples_split":[2,3]
}
grid_dt = GridSearchCV(dt_clf,param_grid=param,cv=3,refit=True)
grid_dt.fit(x_train,y_train)
# GridSearch 결과 추출 하여 DataFrame으로 변환 attribute들을 컬럼으로 하여 나타냄
scores_df = pd.DataFrame(grid_dt.cv_results_)
scores_df.iloc[:,4:]
param_max_depth param_min_samples_split params split0_test_score split1_test_score split2_test_score mean_test_score std_test_score rank_test_score
0 1 2 {'max_depth': 1, 'min_samples_split': 2} 0.700 0.7 0.70 0.700000 1.110223e-16 5
1 1 3 {'max_depth': 1, 'min_samples_split': 3} 0.700 0.7 0.70 0.700000 1.110223e-16 5
2 2 2 {'max_depth': 2, 'min_samples_split': 2} 0.925 1.0 0.95 0.958333 3.118048e-02 3
3 2 3 {'max_depth': 2, 'min_samples_split': 3} 0.925 1.0 0.95 0.958333 3.118048e-02 3
4 3 2 {'max_depth': 3, 'min_samples_split': 2} 0.975 1.0 0.95 0.975000 2.041241e-02 1
5 3 3 {'max_depth': 3, 'min_samples_split': 3} 0.975 1.0 0.95 0.975000 2.041241e-02 1
print('GridSearchCV 최적 파라미터:',grid_dt.best_params_)
print(f'GridSearchCV 최고 정확도: {grid_dt.best_score_}')
# GridSearchCV의 refit으로 이미 학습이 된 estimator 반환
# 최적의 하이퍼파라미터로 DecisionTree에 적용된 모형
estimator = grid_dt.best_estimator_
## GridSearchCV의 best_estimator_로 학습된 모형의 예측
pred = estimator.predict(x_test)
accuracy = accuracy_score(y_test,pred)
print(f'테스트 데이터 세트 정확도 {accuracy}')
GridSearchCV 최적 파라미터: {'max_depth': 3, 'min_samples_split': 2}
GridSearchCV 최고 정확도: 0.975
테스트 데이터 세트 정확도 0.9666666666666667