본문 바로가기

Computer Science/Machine Learning

Cross Validation, Grid Search with GridSearchCV

우리가 의사결정 나무를 사용하면, 정확도가 아주 높지는 않습니다. 대략 테스트 셋의 60-80% 정도의 accuracy를 보이죠. decision tree를 구성하는 파라미터는 아주 많은데, 각각의 파라미터에 어떤 값을 넣어주어야 할까요? 이때 각각의 파라미터를 바꿔가며 모델을 여러가 맨들어 최적의 파라미터를 찾아주는 함수가 있습니다. 바로 Grid Search 입니다. 게다가 train set과 test set을 한 번 만 나누지 않고 Cross Validation을 사용해서 각 매개변수의 성능을 평가할 수도 있습니다..

 

데이터는 저번 포스팅과 같은 Breast cancer 데이터를 사용할 것입니다. Decision tree Classifier와 Accuracy score 라이브러리도 불러옵니다. train test split 함수는 저번 포스팅에서 보였었고요, Grid Search CV 함수가 트리의 최적 depth를 계산할 것 입니다.

 

 

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

from sklearn.model_selection import train_test_split, GridSearchCV

 

X, y = load_breast_cancer(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size = 0.2, random_state = 6)

 

데이터를 불러오고, train과 test set으로 쪼개는 split 함수를 호출합니다. test set은 전체의 20%로 설정, random state를 설정해 결과가 동일하게 고정합니다.

 

clf_try = DecisionTreeClassifier(random_state = 5)
n_folds = 5
range_depth = np.linspace(1, 10, 10, dtype='int')

 

 

clf_try라는 변수에 의사결정 나무를 저장합니다.

n_folds라는 변수에는 몇번 cross validation 할건지를 설정했습니다.

range_depth라는 변수에는 벡터를 저장했는데, 1부터 시작해 10에서 멈추는 10개의 정수를 저장했습니다. 저장된 값은 [ 1 2 3 4 5 6 7 8 9 10 ]입니다. list를 사용해도 무방합니다.

 

grid_param = {'max_depth': range_depth}

grid_clf = GridSearchCV(clf_try, grid_param, 
                        scoring = 'accuracy', cv = n_folds)
grid_clf.fit(X_train, y_train)

 

grid_param이라는 변수에 '딕셔너리' 형태로 파라미터들을 저장합니다. 이때 파라미터 명은 반드시 string이어야 하고,  그 값들은 list 형태로 정의합니다.

*딕셔너리 : { 파라미터의 이름을 정의 : 파라미터가 가질 수 있는 값들, ... } 형식으로 정의합니다.

grid_clf 객체를 생성해 GridSearchCV 함수 호출, 1. Classifiers, 2. 파라미터를 저장한 딕셔너리, 3. optional) scoring을 정의할 수 있음 (디폴트=없음), 4. Cross Validation을 몇번 수행할 것인지 를 지정합니다.

 

언뜻보면 DecisionTreeClassifier 함수와 비슷해보입니다. 다만 GridSearchCV는 다른 Classifiers 간에 적용되는 함수라는 점, 각각의 파라미터를 바꿔가며 계산한다는 점, cross validation도 수행한다는 점이 다르겠습니다. 그러면 return 값이 accuracy가 됩니다.

 

그리고 마지막, .fit 함수를 사용해 train 데이터에 학습시킵니다.

 

그상태로 grid_clf.cv_results_를 출력하면 다양한 숫자들이 출력되는데요. 이름이 mean_fit_time(시간이 얼마나 걸렸냐) 인 것도 있고, std_fit_time도 있고, ... 그 중 우리가 관심있는 것이 mean_test_score과 std_test_score이기 때문에 따로 변수를 생성해 저장하도록 하겠습니다.

 

scores = grid_clf.cv_results_['mean_test_score']
scores_std = grid_clf.cv_results_['std_test_score']

 

출력결과 [0.91428571 0.93186813 0.94065934 0.92527473 0.93846154 0.92527473
 0.92307692 0.92307692 0.92307692 0.92307692]
[0.01281528 0.01457857 0.01120664 0.03061184 0.01916    0.02981244
 0.03029461 0.03029461 0.03029461 0.03029461]

 

평균 뿐만 아니라 표준편차도 decision making에 중요한 역할을 합니다.

 

 

이제 결과가 거의 다 나왔습니다. 평균 score가 가장 높은 세번째 값 0.94065934이 우리가 얻을 수 있는 최대 정확도라고 할 수 있겠습니다.

 

다음으로는 이 score들을 plot해보고자 하는데요.

 

plt.figure()
plt.plot(range_depth, scores, 'b')
plt.plot(range_depth, scores + scores_std, 'b--')
plt.plot(range_depth, scores - scores_std, 'b--')
plt.xlabel('max_depth')
plt.ylabel('CV accuracy score')
plt.show()

best_grid = grid_clf.best_params_.copy()
print('best settings: ', best_grid)

clf = grid_clf.best_estimator_

y_pred_train = clf.predict(X_train)
y_pred_test = clf.predict(X_test)
acc_train = accuracy_score(y_pred_train, y_train)
acc_test = accuracy_score(y_pred_test, y_test)
print("train/test accuracy %.2f/%.2f" % (acc_train, acc_test))

 

 

 

 

x축이 depth이고 y축이 정확도인 그래프입니다. depth = 3인 지점에서 정확도가 가장 높고, 편차도 크지 않습니다.

train과 test 예측의 정확도를 출력해보면, 0.97과 0.92가 출력됩니다.

 

또한 우리는 어떤 parameter를 고를 것인지도 결정할 수 있습니다. 위에서는 depth가 3일때 높은 score와 낮은 표준편차를 보였죠. 즉 그 모델이 가장 좋은 성능을 보이고, 5 CV folds 일때 가장 안정적이라는 뜻입니다. 

이런 식으로 다른 parameter 들을 딕셔너리에 추가하면서 모델을 최적화 시켜볼 수 있습니다. 우리가 entropy나 gini 값을 트리 구성에 추가하고 싶다고 가정한다면, 다음과 같이 추가합니다. 

 

grid_param = {'max_depth': range_depth,
              'criterion': ['gini', 'entropy']}

 

이렇게 작성하면 모델이 테스트 해야하는 숫자가 늘어납니다. 

 

 

추후 추가

'Computer Science > Machine Learning' 카테고리의 다른 글

Learning Curves  (0) 2020.12.28
Getting started with the Scikit-Learn library  (0) 2020.12.14
Dendrograms and Heat Plots  (0) 2020.12.14
Hierarchical Clustering에 대하여  (0) 2020.12.13
Random Forest 배워보기  (0) 2020.07.10