# mglearn 설치 필요
# !pip install mglearn
Collecting mglearn Downloading mglearn-0.1.9.tar.gz (540 kB) |████████████████████████████████| 540 kB 7.0 MB/s Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.21.6) Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mglearn) (3.2.2) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.0.2) Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.3.5) Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from mglearn) (7.1.2) Requirement already satisfied: cycler in /usr/local/lib/python3.7/dist-packages (from mglearn) (0.11.0) Requirement already satisfied: imageio in /usr/local/lib/python3.7/dist-packages (from mglearn) (2.4.1) Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.1.0) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (1.4.2) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (3.0.8) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (2.8.2) Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->mglearn) (4.2.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->mglearn) (1.15.0) Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->mglearn) (2022.1) Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mglearn) (3.1.0) Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from 
scikit-learn->mglearn) (1.4.1) Building wheels for collected packages: mglearn Building wheel for mglearn (setup.py) ... done Created wheel for mglearn: filename=mglearn-0.1.9-py2.py3-none-any.whl size=582639 sha256=f205d10bbb5511c252273e58983f75aeb6975a0d870335da6ef8d39dc9684b4f Stored in directory: /root/.cache/pip/wheels/f1/17/e1/1720d6dcd70187b6b6c3750cb3508798f2b1d57c9d3214b08b Successfully built mglearn Installing collected packages: mglearn Successfully installed mglearn-0.1.9
import matplotlib.pyplot as plt
import mglearn
# Open a 10x10-inch figure, then draw mglearn's built-in example decision tree.
plt.figure(figsize=(10, 10))
mglearn.plots.plot_animal_tree()
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
import seaborn as sns
# Load the breast-cancer dataset; X holds the features, y the class labels.
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target

# Hold out 30% of the samples for testing.  Stratifying on the labels keeps
# the class ratio the same in both splits; the seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    stratify=y,
    random_state=77,
)

# Fit a decision tree limited to depth 2 and report accuracy on both splits.
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train, y_train)
print(f"훈련 세트 정확도 : {tree.score(X_train, y_train):.3f}")
print(f"테스트 세트 정확도 : {tree.score(X_test, y_test):.3f}")
훈련 세트 정확도 : 0.972 테스트 세트 정확도 : 0.912
# Sweep max_depth over 1..6, refitting the tree each time and printing the
# train/test accuracy for every depth.  Both `i` and `tree` stay bound at
# module level after the loop (to 6 and to the depth-6 tree respectively).
for i in range(1, 7):
    tree = DecisionTreeClassifier(max_depth=i, random_state=0).fit(X_train, y_train)
    print(f"max_depth : {i}")
    print(f"훈련 세트 정확도 : {tree.score(X_train, y_train):.3f}")
    print(f"테스트 세트 정확도 : {tree.score(X_test, y_test):.3f}")
max_depth : 1 훈련 세트 정확도 : 0.932 테스트 세트 정확도 : 0.883 max_depth : 2 훈련 세트 정확도 : 0.972 테스트 세트 정확도 : 0.912 max_depth : 3 훈련 세트 정확도 : 0.982 테스트 세트 정확도 : 0.906 max_depth : 4 훈련 세트 정확도 : 0.985 테스트 세트 정확도 : 0.906 max_depth : 5 훈련 세트 정확도 : 0.992 테스트 세트 정확도 : 0.889 max_depth : 6 훈련 세트 정확도 : 0.997 테스트 세트 정확도 : 0.901
# Refit the tree at depth 2 so the following cells work with that model.
tree = DecisionTreeClassifier(max_depth=2, random_state=0)
tree.fit(X_train, y_train)

# Report the depth configured on the estimator itself.  The loop variable `i`
# is still bound to the last value of the earlier sweep, so printing it here
# would label these depth-2 scores with the wrong max_depth.
print(f"max_depth : {tree.max_depth}")
print(f"훈련 세트 정확도 : {tree.score(X_train, y_train):.3f}")
print(f"테스트 세트 정확도 : {tree.score(X_test, y_test):.3f}")
max_depth : 6 훈련 세트 정확도 : 0.972 테스트 세트 정확도 : 0.912
from sklearn.tree import export_graphviz
import graphviz
# Write the fitted tree to tree.dot in Graphviz DOT format.
export_graphviz(
    tree,
    out_file="tree.dot",
    class_names=['악성', '양성'],
    feature_names=cancer.feature_names,
    impurity=False,  # leave the impurity (gini) value out of each node
    filled=True,  # colour the nodes
)

# export_graphviz writes a path given as a string in UTF-8.  Read it back with
# the same encoding so the non-ASCII class names survive on systems whose
# default locale encoding is not UTF-8.
with open("tree.dot", encoding="utf-8") as f:
    dot_graph = f.read()

# NOTE(review): `display` is not imported in this file -- presumably the
# IPython/Jupyter builtin; confirm before running outside a notebook.
display(graphviz.Source(dot_graph))
특성 중요도 : 이 값은 0과 1 사이의 숫자이다.
특성의 feature_importances_ 값이 낮다고 해서 그 특성이 유용하지 않다는 뜻은 아니다.
import numpy as np
def plot_feature_imp_cancer(model):
    """Draw a horizontal bar chart of ``model.feature_importances_``.

    One bar is drawn per column of the module-level ``cancer.data`` and is
    labelled with the matching entry of ``cancer.feature_names``, so the
    ``cancer`` dataset must already be loaded.  Drawing goes through
    ``matplotlib.pyplot``, i.e. onto the current figure.

    Args:
        model: Fitted estimator exposing ``feature_importances_``.
    """
    feature_count = cancer.data.shape[1]
    positions = np.arange(feature_count)

    plt.barh(positions, model.feature_importances_, align='center')
    plt.yticks(positions, cancer.feature_names)
    plt.xlabel("feature importance")
    plt.ylabel("feature")
    # Bars sit at 0..feature_count-1; pad one unit on either side.
    plt.ylim(-1, feature_count)
# Open a 12x12-inch figure and chart the feature importances of the fitted tree.
plt.figure(figsize=(12, 12))
plot_feature_imp_cancer(tree)