Machine Learning

Decision Tree

돼지표 2022. 6. 14. 00:26

Import libraries

 

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV, cross_validate
import numpy as np
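The snippets in this post use train_x, train_y, test_x and test_y without showing where they come from. A minimal sketch of one way to produce them, assuming the breast cancer dataset bundled with scikit-learn and an 80/20 stratified split (both are assumptions for illustration; the post does not say which data or split was used):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Assumed example data: 569 samples, 30 numeric features, binary target.
X, Y = load_breast_cancer(return_X_y=True)

# Hold out 20% as a test set; stratify keeps the class ratio similar in both parts.
train_x, test_x, train_y, test_y = train_test_split(
    X, Y, test_size=0.2, stratify=Y, random_state=42
)

print(train_x.shape, test_x.shape)
```

Fixing random_state only makes the split reproducible; any other seed works as well.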

 

 

Hyperparameter tuning

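# Candidate values for each hyperparameter (9 * 6 * 9 = 486 combinations)
param = {
    "min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9],
    "max_depth": [2, 3, 4, 5, 6, None],
    "min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10]
}
# Exhaustive search with the default 5-fold cross-validation, using all CPU cores
gs = GridSearchCV(DecisionTreeClassifier(), param, n_jobs=-1)
# Fit on the training split only, so the test set stays unseen during the search
gs.fit(train_x, train_y)

# best_estimator_ is refit on the whole training split with the best parameters
dt = gs.best_estimator_
print(dt.score(test_x, test_y))

print(gs.best_params_)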
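Besides best_params_, the fitted search object also keeps the cross-validation score of every combination. A rough sketch of how to look at them, assuming the breast cancer dataset from scikit-learn and a much smaller grid (small_param is a made-up name for this example) so that it runs quickly:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumed example data; replace with your own train/test split.
X, Y = load_breast_cancer(return_X_y=True)
train_x, test_x, train_y, test_y = train_test_split(
    X, Y, test_size=0.2, stratify=Y, random_state=42
)

# Hypothetical reduced grid: 3 * 2 = 6 combinations.
small_param = {"max_depth": [2, 4, None], "min_samples_leaf": [1, 5]}
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), small_param)
gs.fit(train_x, train_y)

# Mean cross-validated accuracy of the best combination.
print(gs.best_score_)

# One entry per parameter combination that was tried.
n_candidates = len(gs.cv_results_["params"])
for params, mean in zip(gs.cv_results_["params"], gs.cv_results_["mean_test_score"]):
    print(params, round(mean, 4))
```

best_score_ is measured on validation folds inside the training split, so it is a separate number from dt.score(test_x, test_y).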

 

Best model

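# Best parameters reported by the grid search
param = {'max_depth': 6, 'min_samples_leaf': 1, 'min_samples_split': 4}
dt = DecisionTreeClassifier(**param)
dt.fit(train_x, train_y)

# 5-fold cross-validation on the training split; cross_validate clones dt for each fold
scores = cross_validate(dt, train_x, train_y, return_train_score=True, n_jobs=-1)
print("cross_validate")
# Mean train score vs. mean validation score; a large gap suggests overfitting
print(np.mean(scores['train_score']), np.mean(scores['test_score']))
print("test score")
# Accuracy of the fitted model on the held-out test set
print(dt.score(test_x, test_y))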
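After checking the scores, it can help to look at what the tree actually learned. A small sketch using export_text from sklearn.tree, again assuming the breast cancer dataset (an assumption for illustration, not the data used in this post):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumed example data; feature names are used only to label the printed rules.
data = load_breast_cancer()
train_x, test_x, train_y, test_y = train_test_split(
    data.data, data.target, test_size=0.2, stratify=data.target, random_state=42
)

param = {'max_depth': 6, 'min_samples_leaf': 1, 'min_samples_split': 4}
dt = DecisionTreeClassifier(random_state=42, **param)
dt.fit(train_x, train_y)

# Actual depth and number of leaves of the fitted tree.
print(dt.get_depth(), dt.get_n_leaves())

# Text view of the top levels of the tree (deeper branches are truncated).
rules = export_text(dt, feature_names=list(data.feature_names), max_depth=2)
print(rules)

# Impurity-based importance per feature; the values sum to 1.
importances = dt.feature_importances_
```

max_depth only limits how deep the tree may grow, so get_depth() can come out smaller than 6.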