Machine Learning
Decision Tree (결정 트리)
돼지표
2022. 6. 14. 00:26
Import libraries
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier
Hyperparameter tuning
# Grid-search DecisionTreeClassifier hyperparameters with cross-validation
# (GridSearchCV's default `cv`), then report the held-out test accuracy of
# the best model and the winning parameter combination.
param = {
    "min_samples_leaf": list(range(1, 10)),   # 1..9
    "max_depth": [2, 3, 4, 5, 6, None],       # None = no depth limit
    "min_samples_split": list(range(2, 11)),  # 2..10
}
# A fixed random_state makes the tree fitting deterministic, so the selected
# best_params_ are reproducible from run to run.
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), param, n_jobs=-1)
# Tune on the training split only: fitting the search on data that includes
# test_x / test_y would leak the test set into model selection and inflate
# the test score printed below.
gs.fit(train_x, train_y)
# With GridSearchCV's default refit=True, best_estimator_ is the best
# parameter set refit on the whole training split.
dt = gs.best_estimator_
print(dt.score(test_x, test_y))
print(gs.best_params_)
Best model
# Retrain a tree with fixed hyperparameters (presumably copied from the
# grid search's best_params_ printed above), then report mean cross-validated
# train/validation scores and the score on the held-out test split.
param = {"max_depth": 6, "min_samples_leaf": 1, "min_samples_split": 4}
# A fixed random_state makes the fitted tree (and the CV scores) reproducible.
dt = DecisionTreeClassifier(**param, random_state=42)
dt.fit(train_x, train_y)
# cross_validate fits clones of dt on each fold, so it does not depend on the
# fit above; return_train_score=True adds per-fold training scores so the
# train/validation gap (overfitting) is visible.
scores = cross_validate(dt, train_x, train_y, return_train_score=True, n_jobs=-1)
print("cross_validate")
# mean training-fold score, mean validation-fold score
print(np.mean(scores["train_score"]), np.mean(scores["test_score"]))
print("test score")
print(dt.score(test_x, test_y))