붓꽃의 3가지 품종 versicolor, virginica, setosa를 꽃잎(petal), 꽃받침(sepal)의 길이, 너비로 구분해보자
from sklearn.datasets import load_iris
iris = load_iris()
iris.keys()
iris.feature_names
iris_pd = pd.DataFrame(iris.data, columns = iris.feature_names)
iris_pd.head()
#iris.target은 붓꽃 품종에 따라 0, 1, 2로 분류해놓은 값
iris_pd['species'] = iris.target
iris_pd.head()
import seaborn as sns
import matplotlib.pyplot as plt
sns.pairplot(iris_pd, hue = 'species' )
from sklearn.tree import DecisionTreeClassifier
from sklearm.metrics import accuracy_score
iris_tree = DecisionTreeClassifier()
## 학습
iris_tree.fit(iris.data[:, 2:], iris.target)
## 예측
y_pred_tr = iris_tree.predict(iris.data[:, 2:])
accuracy_score(iris.target, y_pred_tr)
위 성능은 이미 답을 알려준 상태에서 문제를 풀어보고 테스트한 성능이므로 신뢰도가 제한적이다. 성능에 대한 평가가 유효하기 위해서는 학습용과 테스트용 데이터 셋을 분리할 필요가 있다.
from sklearn.model_selection import train_test_split
feature = iris.data[:, 2:]
labels = iris.target
# test_size= 0.2 -> 8:2 비율로 학습용과 테스트용 데이터 셋 분리
# stratify = labels -> labels 기준으로 분포 고르게 분리
X_train, X_test, y_train, y_test = train_test_split(feature, labels,\
test_size = 0.2,\
random_state = 13, stratify= labels)
X_train.shape, X_test.shape
import numpy as np
np.unique(y_test, return_counts= True)
## DecisionTreeClassifier의 몇가지 옵션(하이퍼파라미터)을 다르게 설정
iris_tree = DecisionTreeClassifier(max_depth = 2, random_state = 13)
iris_tree.fit(X_train, y_train)
y_pred_tr = iris_tree.predict(iris.data[:, 2:])
y_pred_test = iris_tree.predict(X_test)
accuracy_score(iris.target, y_pred_tr)
accuracy_score(y_test, y_pred_test)
#!pip install mlxtend
from mlxtend.plotting import plot_decision_regions
# 각각 사이즈, 라벨, 투명도
scatter_highlight_kwargs = {'s': 150, 'label': 'Test data', 'alpha' :0.9}
scatter_kwargs = {'s': 120, 'edgecolor': None, 'alpha' :0.9}
plt.figure(figsize = (12,8))
plot_decision_regions\
(X=feature, y= labels, X_highlight = X_test, clf = iris_tree, \
legend = 2, scatter_highlight_kwargs = scatter_highlight_kwargs, \
scatter_kwargs = scatter_kwargs, contourf_kwargs = {'alpha':0.2})
plt.show()
from sklearn.tree import plot_tree
plt.figure(figsize=(12, 8))
plot_tree(iris_tree);
분할하기 전
분할 후
엔트로피가 내려갔으므로 분할하는 것이 나음.