# Pairwise scatter plots of every iris feature, coloured by species.
sns.pairplot(data=iris_pd, hue='species')
# Same plot restricted to the two petal measurements, with larger panels.
sns.pairplot(iris_pd, hue='species', height=4,
             vars=['petal length (cm)', 'petal width (cm)'])
- petal length와 petal width만으로도 iris 3종을 구분할 수 있을 것으로 보인다.
from sklearn.tree import DecisionTreeClassifier

# Fix random_state so the fitted tree is reproducible: scikit-learn randomly
# permutes the features at each split, so equally good splits may otherwise be
# chosen differently from run to run. 13 is the same seed used for the later
# data split and trees in this notebook.
iris_tree = DecisionTreeClassifier(random_state=13)
---
from sklearn.metrics import accuracy_score

# Train on the petal features only (columns 2 and 3: petal length and width),
# with the species codes in iris.target as the labels.
iris_tree.fit(X=iris.data[:, 2:], y=iris.target)

# Predict back on the same 150 samples the tree was trained on, to see how
# well it fits its own training data.
y_pred_tr = iris_tree.predict(X=iris.data[:, 2:])
y_pred_tr
iris.target
from sklearn.tree import plot_tree

# Visualise the structure of the tree fitted above on a fresh 12x8 figure;
# the trailing semicolon suppresses the notebook's echo of the return value.
plt.figure(figsize=(12, 8))
plot_tree(iris_tree);
# mlxtend is a separate package; install it first if it is missing:
# !pip install mlxtend
from mlxtend.plotting import plot_decision_regions

# Shade the class regions the tree produces in the petal length/width plane,
# with all 150 samples drawn on top.
plt.figure(figsize=(14, 8))
plot_decision_regions(X=iris.data[:, 2:], y=iris.target,
                      clf=iris_tree, legend=2)
plt.show()
# Accuracy of the petal-only tree on the same samples it was trained on.
y_pred_tr = iris_tree.predict(X=iris.data[:, 2:])
accuracy_score(y_true=iris.target, y_pred=y_pred_tr)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd

iris = load_iris()

# Define the model inputs before splitting (previously they were only assigned
# further down, so this split referenced undefined names). This section uses
# the two petal features (length and width) so that the 2-D decision-region
# plots later in the section can be drawn.
features = iris.data[:, 2:]
labels = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    features,
    labels,
    test_size=0.2,
    stratify=labels,  # keep the same class proportions in train and test
    random_state=13,
)
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree

# Shallow tree (at most two levels of splits) fitted on the training split only;
# the fixed random_state makes the fitted tree reproducible.
iris_tree = DecisionTreeClassifier(random_state=13, max_depth=2)
iris_tree.fit(X_train, y_train)

plt.figure(figsize=(12, 8))
plot_tree(iris_tree);
# Accuracy on the training split. Evaluate on X_train/y_train rather than the
# full dataset: the full dataset also contains the held-out test samples, which
# would blur the train vs. test comparison. Using X_train also guarantees the
# feature columns match exactly what the tree was fitted on.
from sklearn.metrics import accuracy_score
y_pred_tr = iris_tree.predict(X_train)
accuracy_score(y_train, y_pred_tr)
from mlxtend.plotting import plot_decision_regions

# Decision regions of the tree fitted on the training split, with only the
# training samples drawn on top.
plt.figure(figsize=(14, 8))
plot_decision_regions(clf=iris_tree, X=X_train, y=y_train, legend=2)
plt.show()
# Accuracy on the held-out 20% test split, which the tree never saw in fit().
y_pred_test = iris_tree.predict(X=X_test)
accuracy_score(y_true=y_test, y_pred=y_pred_test)
# Decision regions over the whole dataset, with the test samples drawn as
# larger highlighted markers (s=150 vs. 120 for the regular points).
scatter_highlight_kwargs = {'s': 150, 'label': 'Test data', 'alpha': 0.9}
scatter_kwargs = {'s': 120, 'edgecolor': None, 'alpha': 0.9}

plt.figure(figsize=(12, 8))
plot_decision_regions(
    X=features,
    y=labels,
    X_highlight=X_test,
    clf=iris_tree,
    legend=2,
    scatter_highlight_kwargs=scatter_highlight_kwargs,
    scatter_kwargs=scatter_kwargs,
    contourf_kwargs={'alpha': 0.2},  # faint shading so the points stay visible
)
# Render explicitly, like the other decision-region plots, so the figure is
# shown when run as a script and not merged with the next figure.
plt.show()
# Repeat the experiment using every column of iris.data instead of only the
# petal pair: same split settings, same depth-2 tree.
labels = iris.target
features = iris.data

X_train, X_test, y_train, y_test = train_test_split(
    features, labels,
    test_size=0.2,
    stratify=labels,  # same class proportions in train and test
    random_state=13,
)

iris_tree = DecisionTreeClassifier(random_state=13, max_depth=2)
iris_tree.fit(X_train, y_train)

plt.figure(figsize=(12, 8))
plot_tree(iris_tree);