GridSearchCV cross-validation

Code implementation (based on the logistic regression algorithm):

 # -*- coding: utf-8 -*-
"""
Created on Sat Sep  1 11:54:48 2018

@author: zhen

Cross-validation with GridSearchCV.

Fits a logistic-regression classifier on a single iris feature (column 3,
petal width), searches over the regularization strength ``C`` and the
stopping tolerance ``tol`` with 3-fold cross-validation, prints the
best-ranked parameter settings and plots the class probabilities predicted
by the best model.
"""
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt

iris = datasets.load_iris()
# Keep only feature column 3 (petal width); slicing with 3: keeps x 2-D,
# shape (n_samples, 1), as scikit-learn estimators expect.
x = iris['data'][:, 3:]
y = iris['target']


def report(results, n_top=3):
    """Print the best-ranked parameter settings found by a grid search.

    Parameters
    ----------
    results : dict
        A ``cv_results_`` mapping as produced by ``GridSearchCV``; the
        ``'rank_test_score'``, ``'mean_test_score'``, ``'std_test_score'``
        and ``'params'`` entries are read.
    n_top : int, default 3
        Print every candidate whose rank is 1..n_top. Candidates with equal
        scores share a rank, so more than ``n_top`` models may be printed.
    """
    ranks = np.asarray(results['rank_test_score'])
    for rank in range(1, n_top + 1):
        # Several candidates can share the same rank (ties).
        for candidate in np.flatnonzero(ranks == rank):
            print("Model with rank: {0}".format(rank))
            print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
                  results['mean_test_score'][candidate],
                  results['std_test_score'][candidate]))
            print("Parameters: {0}".format(results['params'][candidate]))
            print("")


# 3 x 3 = 9 parameter combinations are evaluated.
param_grid = {"tol": [1e-4, 1e-3, 1e-2], "C": [0.4, 0.6, 0.8]}

# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# on those versions wrap the model in sklearn.multiclass.OneVsRestClassifier
# instead (the grid keys then become "estimator__tol" / "estimator__C").
log_reg = LogisticRegression(multi_class='ovr', solver='sag')
# 3-fold cross-validation: every parameter combination is trained on two
# folds and scored on the third, rotating through all three folds.
grid_search = GridSearchCV(log_reg, param_grid=param_grid, cv=3)
# With the default refit=True, the best combination is refit on all of x,
# and grid_search.predict / predict_proba use that refit model.
grid_search.fit(x, y)

report(grid_search.cv_results_)

# Evaluate the best model on a dense grid of feature values.
x_new = np.linspace(0, 3, 1000).reshape(-1, 1)
y_proba = grid_search.predict_proba(x_new)
y_hat = grid_search.predict(x_new)

plt.plot(x_new, y_proba[:, 2], 'g-', label='Iris-Virginica')
plt.plot(x_new, y_proba[:, 1], 'r-', label='Iris-Versicolour')
plt.plot(x_new, y_proba[:, 0], 'b-', label='Iris-Setosa')
# Mark the decision boundaries: grid points where the predicted class changes.
for boundary in np.flatnonzero(np.diff(y_hat)):
    plt.axvline(x_new[boundary + 1, 0], color='k', linestyle=':')
plt.xlabel('Petal width')
plt.ylabel('Probability')
# Without this call the `label=` arguments above are never displayed.
plt.legend()
plt.show()

print(grid_search.predict([[1.7], [1.5]]))

Result:

Summary: GridSearchCV uses cross-validation to automatically train and evaluate the model for every combination of parameters in a given grid. It then selects the best-scoring parameters and refits the model with them for prediction, so that the final model achieves the best predictive performance among the candidates.

Leave a Reply

Your email address will not be published. Required fields are marked *