GregOliveira commited on
Commit
9605944
1 Parent(s): cedfa43
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import joblib as jb
6
+ import os.path as path
7
+
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # Loading features names
12
+ file_csv = path.join("model" ,"model_features.csv")
13
+ with open(file_csv) as f:
14
+ reader = csv.reader(f)
15
+ data = list(reader)
16
+
17
+ features = data[0]
18
+
19
+ # Creating a list with accident types
20
+ accident_type_list = [None,
21
+ "type_ATROPELAMENTO",
22
+ "type_CHOQUE",
23
+ "type_COLISÃO",
24
+ "type_OUTROS"]
25
+
26
+ # Loading the scaler
27
+ file_scaler_feridos = path.join("model" ,"scaler_feridos.pkl")
28
+ scaler_feridos = jb.load(file_scaler_feridos)
29
+
30
+ # Loading the model
31
+ file_model_feridos = path.join("model" ,"model_feridos.pkl")
32
+ model_feridos = jb.load(file_model_feridos)
33
+
34
+ def fit_inputs_injured(latitude,
35
+ longitude,
36
+ caminhao,
37
+ moto,
38
+ cars,
39
+ transport,
40
+ others,
41
+ holiday,
42
+ week_day,
43
+ hour_day,
44
+ accident_type) -> np.array:
45
+ """This function will process data input
46
+ from use to use in the model"""
47
+ input_dict = {col: False for col in features}
48
+
49
+ input_dict["latitude"] = latitude
50
+ input_dict["longitude"] = longitude
51
+ input_dict["caminhao"] = caminhao
52
+ input_dict["moto"] = moto
53
+ input_dict["cars"] = cars
54
+ input_dict["transport"] = transport
55
+ input_dict["others"] = others
56
+ input_dict["holiday"] = holiday
57
+
58
+ if week_day != 0:
59
+ input_dict["day_" + str(week_day)] = True
60
+
61
+ if hour_day != 0:
62
+ input_dict["hour_" + str(hour_day)] = True
63
+
64
+ if accident_type != 0:
65
+ input_dict[accident_type_list[accident_type]] = True
66
+
67
+ input_series = pd.Series(input_dict)
68
+
69
+ input_array = input_series.to_numpy().reshape(1,-1)
70
+
71
+ input_scaled = scaler_feridos.transform(input_array)
72
+
73
+ return input_scaled
74
+
75
+ def predict(
76
+ latitude,
77
+ longitude,
78
+ caminhao,
79
+ moto,
80
+ cars,
81
+ transport,
82
+ others,
83
+ holiday,
84
+ week_day,
85
+ hour_day,
86
+ accident_type) -> dict:
87
+ """This function will be call by gradio
88
+ when on submit action."""
89
+
90
+ input_to_predict = fit_inputs_injured(latitude,
91
+ longitude,
92
+ caminhao,
93
+ moto,
94
+ cars,
95
+ transport,
96
+ others,
97
+ holiday,
98
+ week_day,
99
+ hour_day,
100
+ accident_type)
101
+
102
+ predic_injured = model_feridos.predict_proba(input_to_predict)
103
+
104
+ return {"No": predic_injured[0][0], "Yes": predic_injured[0][1]}
105
+
106
+ demo = gr.Interface(
107
+ fn=predict,
108
+ inputs=[gr.Slider(
109
+ minimum=-31.054,
110
+ maximum=-29.054,
111
+ step=0.001,
112
+ value=-30.054,
113
+ label="Latitude"),
114
+ gr.Slider(
115
+ minimum=-52.196,
116
+ maximum=-50.196,
117
+ step=0.001,
118
+ value=-51.196,
119
+ label="Longitude"),
120
+ gr.Checkbox(label="Trucks involved?"),
121
+ gr.Checkbox(label="Motorcycle involved?"),
122
+ gr.Checkbox(label="Cars involved?"),
123
+ gr.Checkbox(label="Bus involved?"),
124
+ gr.Checkbox(label="Other vehicle (i.e. scooter) involved?"),
125
+ gr.Checkbox(label="Is holiday?"),
126
+ gr.Radio(
127
+ choices=["Sun", "Mon",
128
+ "Tue", "Wed",
129
+ "Thu", "Fri",
130
+ "Sat"],
131
+ type="index",
132
+ label="Day of Week"),
133
+ gr.Slider(
134
+ minimum=0,
135
+ maximum=23,
136
+ step=1,
137
+ label="Hour"),
138
+ gr.Dropdown(
139
+ choices=["Violent Collision",
140
+ "Running over",
141
+ "Shock",
142
+ "Collision",
143
+ "Other"],
144
+ type="index",
145
+ label="Accident type")],
146
+ outputs=gr.Label(
147
+ label="Are there people injured?"))
148
+
149
+ demo.launch()
model/model_features.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ latitude,longitude,caminhao,moto,cars,transport,others,holiday,day_1,day_2,day_3,day_4,day_5,day_6,hour_1,hour_2,hour_3,hour_4,hour_5,hour_6,hour_7,hour_8,hour_9,hour_10,hour_11,hour_12,hour_13,hour_14,hour_15,hour_16,hour_17,hour_18,hour_19,hour_20,hour_21,hour_22,hour_23,type_ATROPELAMENTO,type_CHOQUE,type_COLIS�O,type_OUTROS
model/model_feridos.ipynb ADDED
@@ -0,0 +1,1659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 1. Introduction\n",
8
+ "\n",
9
+ "This notebook was written to train Porto Alegre Traffic Accidents Data after the first cleaning, processing, and transforming step. This was made in a notebook in the `data` folder. In truth, we will have 3 models.\n",
10
+ "\n",
11
+ "1. Predict the probability of injured people.\n",
12
+ "\n",
13
+ "2. Predict the probability of seriously injured people.\n",
14
+ "\n",
15
+ "3. Predict the probability of dead people in the event or after it.\n",
16
+ "\n",
17
+ "The path to training the models will be the same, just make some filtering on data and analyze the results properly."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "# 2. Data Loading"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 1,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "data": {
34
+ "text/html": [
35
+ "<div>\n",
36
+ "<style scoped>\n",
37
+ " .dataframe tbody tr th:only-of-type {\n",
38
+ " vertical-align: middle;\n",
39
+ " }\n",
40
+ "\n",
41
+ " .dataframe tbody tr th {\n",
42
+ " vertical-align: top;\n",
43
+ " }\n",
44
+ "\n",
45
+ " .dataframe thead th {\n",
46
+ " text-align: right;\n",
47
+ " }\n",
48
+ "</style>\n",
49
+ "<table border=\"1\" class=\"dataframe\">\n",
50
+ " <thead>\n",
51
+ " <tr style=\"text-align: right;\">\n",
52
+ " <th></th>\n",
53
+ " <th>0</th>\n",
54
+ " <th>1</th>\n",
55
+ " <th>2</th>\n",
56
+ " </tr>\n",
57
+ " </thead>\n",
58
+ " <tbody>\n",
59
+ " <tr>\n",
60
+ " <th>latitude</th>\n",
61
+ " <td>-30.009614</td>\n",
62
+ " <td>-30.0403</td>\n",
63
+ " <td>-30.069</td>\n",
64
+ " </tr>\n",
65
+ " <tr>\n",
66
+ " <th>longitude</th>\n",
67
+ " <td>-51.185581</td>\n",
68
+ " <td>-51.1958</td>\n",
69
+ " <td>-51.1437</td>\n",
70
+ " </tr>\n",
71
+ " <tr>\n",
72
+ " <th>feridos</th>\n",
73
+ " <td>True</td>\n",
74
+ " <td>True</td>\n",
75
+ " <td>True</td>\n",
76
+ " </tr>\n",
77
+ " <tr>\n",
78
+ " <th>feridos_gr</th>\n",
79
+ " <td>False</td>\n",
80
+ " <td>False</td>\n",
81
+ " <td>False</td>\n",
82
+ " </tr>\n",
83
+ " <tr>\n",
84
+ " <th>fatais</th>\n",
85
+ " <td>False</td>\n",
86
+ " <td>False</td>\n",
87
+ " <td>False</td>\n",
88
+ " </tr>\n",
89
+ " <tr>\n",
90
+ " <th>caminhao</th>\n",
91
+ " <td>False</td>\n",
92
+ " <td>False</td>\n",
93
+ " <td>False</td>\n",
94
+ " </tr>\n",
95
+ " <tr>\n",
96
+ " <th>moto</th>\n",
97
+ " <td>True</td>\n",
98
+ " <td>True</td>\n",
99
+ " <td>False</td>\n",
100
+ " </tr>\n",
101
+ " <tr>\n",
102
+ " <th>cars</th>\n",
103
+ " <td>True</td>\n",
104
+ " <td>True</td>\n",
105
+ " <td>True</td>\n",
106
+ " </tr>\n",
107
+ " <tr>\n",
108
+ " <th>transport</th>\n",
109
+ " <td>False</td>\n",
110
+ " <td>False</td>\n",
111
+ " <td>False</td>\n",
112
+ " </tr>\n",
113
+ " <tr>\n",
114
+ " <th>others</th>\n",
115
+ " <td>False</td>\n",
116
+ " <td>False</td>\n",
117
+ " <td>False</td>\n",
118
+ " </tr>\n",
119
+ " <tr>\n",
120
+ " <th>holiday</th>\n",
121
+ " <td>False</td>\n",
122
+ " <td>True</td>\n",
123
+ " <td>True</td>\n",
124
+ " </tr>\n",
125
+ " <tr>\n",
126
+ " <th>day_1</th>\n",
127
+ " <td>0</td>\n",
128
+ " <td>0</td>\n",
129
+ " <td>0</td>\n",
130
+ " </tr>\n",
131
+ " <tr>\n",
132
+ " <th>day_2</th>\n",
133
+ " <td>0</td>\n",
134
+ " <td>0</td>\n",
135
+ " <td>0</td>\n",
136
+ " </tr>\n",
137
+ " <tr>\n",
138
+ " <th>day_3</th>\n",
139
+ " <td>0</td>\n",
140
+ " <td>0</td>\n",
141
+ " <td>0</td>\n",
142
+ " </tr>\n",
143
+ " <tr>\n",
144
+ " <th>day_4</th>\n",
145
+ " <td>0</td>\n",
146
+ " <td>0</td>\n",
147
+ " <td>0</td>\n",
148
+ " </tr>\n",
149
+ " <tr>\n",
150
+ " <th>day_5</th>\n",
151
+ " <td>1</td>\n",
152
+ " <td>0</td>\n",
153
+ " <td>0</td>\n",
154
+ " </tr>\n",
155
+ " <tr>\n",
156
+ " <th>day_6</th>\n",
157
+ " <td>0</td>\n",
158
+ " <td>1</td>\n",
159
+ " <td>1</td>\n",
160
+ " </tr>\n",
161
+ " <tr>\n",
162
+ " <th>hour_1</th>\n",
163
+ " <td>0</td>\n",
164
+ " <td>0</td>\n",
165
+ " <td>0</td>\n",
166
+ " </tr>\n",
167
+ " <tr>\n",
168
+ " <th>hour_2</th>\n",
169
+ " <td>0</td>\n",
170
+ " <td>0</td>\n",
171
+ " <td>0</td>\n",
172
+ " </tr>\n",
173
+ " <tr>\n",
174
+ " <th>hour_3</th>\n",
175
+ " <td>0</td>\n",
176
+ " <td>0</td>\n",
177
+ " <td>0</td>\n",
178
+ " </tr>\n",
179
+ " <tr>\n",
180
+ " <th>hour_4</th>\n",
181
+ " <td>0</td>\n",
182
+ " <td>0</td>\n",
183
+ " <td>0</td>\n",
184
+ " </tr>\n",
185
+ " <tr>\n",
186
+ " <th>hour_5</th>\n",
187
+ " <td>0</td>\n",
188
+ " <td>0</td>\n",
189
+ " <td>0</td>\n",
190
+ " </tr>\n",
191
+ " <tr>\n",
192
+ " <th>hour_6</th>\n",
193
+ " <td>0</td>\n",
194
+ " <td>0</td>\n",
195
+ " <td>0</td>\n",
196
+ " </tr>\n",
197
+ " <tr>\n",
198
+ " <th>hour_7</th>\n",
199
+ " <td>0</td>\n",
200
+ " <td>0</td>\n",
201
+ " <td>0</td>\n",
202
+ " </tr>\n",
203
+ " <tr>\n",
204
+ " <th>hour_8</th>\n",
205
+ " <td>0</td>\n",
206
+ " <td>0</td>\n",
207
+ " <td>0</td>\n",
208
+ " </tr>\n",
209
+ " <tr>\n",
210
+ " <th>hour_9</th>\n",
211
+ " <td>0</td>\n",
212
+ " <td>0</td>\n",
213
+ " <td>0</td>\n",
214
+ " </tr>\n",
215
+ " <tr>\n",
216
+ " <th>hour_10</th>\n",
217
+ " <td>0</td>\n",
218
+ " <td>1</td>\n",
219
+ " <td>0</td>\n",
220
+ " </tr>\n",
221
+ " <tr>\n",
222
+ " <th>hour_11</th>\n",
223
+ " <td>0</td>\n",
224
+ " <td>0</td>\n",
225
+ " <td>0</td>\n",
226
+ " </tr>\n",
227
+ " <tr>\n",
228
+ " <th>hour_12</th>\n",
229
+ " <td>0</td>\n",
230
+ " <td>0</td>\n",
231
+ " <td>0</td>\n",
232
+ " </tr>\n",
233
+ " <tr>\n",
234
+ " <th>hour_13</th>\n",
235
+ " <td>0</td>\n",
236
+ " <td>0</td>\n",
237
+ " <td>0</td>\n",
238
+ " </tr>\n",
239
+ " <tr>\n",
240
+ " <th>hour_14</th>\n",
241
+ " <td>0</td>\n",
242
+ " <td>0</td>\n",
243
+ " <td>0</td>\n",
244
+ " </tr>\n",
245
+ " <tr>\n",
246
+ " <th>hour_15</th>\n",
247
+ " <td>0</td>\n",
248
+ " <td>0</td>\n",
249
+ " <td>0</td>\n",
250
+ " </tr>\n",
251
+ " <tr>\n",
252
+ " <th>hour_16</th>\n",
253
+ " <td>0</td>\n",
254
+ " <td>0</td>\n",
255
+ " <td>0</td>\n",
256
+ " </tr>\n",
257
+ " <tr>\n",
258
+ " <th>hour_17</th>\n",
259
+ " <td>0</td>\n",
260
+ " <td>0</td>\n",
261
+ " <td>0</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <th>hour_18</th>\n",
265
+ " <td>0</td>\n",
266
+ " <td>0</td>\n",
267
+ " <td>0</td>\n",
268
+ " </tr>\n",
269
+ " <tr>\n",
270
+ " <th>hour_19</th>\n",
271
+ " <td>1</td>\n",
272
+ " <td>0</td>\n",
273
+ " <td>1</td>\n",
274
+ " </tr>\n",
275
+ " <tr>\n",
276
+ " <th>hour_20</th>\n",
277
+ " <td>0</td>\n",
278
+ " <td>0</td>\n",
279
+ " <td>0</td>\n",
280
+ " </tr>\n",
281
+ " <tr>\n",
282
+ " <th>hour_21</th>\n",
283
+ " <td>0</td>\n",
284
+ " <td>0</td>\n",
285
+ " <td>0</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <th>hour_22</th>\n",
289
+ " <td>0</td>\n",
290
+ " <td>0</td>\n",
291
+ " <td>0</td>\n",
292
+ " </tr>\n",
293
+ " <tr>\n",
294
+ " <th>hour_23</th>\n",
295
+ " <td>0</td>\n",
296
+ " <td>0</td>\n",
297
+ " <td>0</td>\n",
298
+ " </tr>\n",
299
+ " <tr>\n",
300
+ " <th>type_ATROPELAMENTO</th>\n",
301
+ " <td>0</td>\n",
302
+ " <td>0</td>\n",
303
+ " <td>1</td>\n",
304
+ " </tr>\n",
305
+ " <tr>\n",
306
+ " <th>type_CHOQUE</th>\n",
307
+ " <td>0</td>\n",
308
+ " <td>0</td>\n",
309
+ " <td>0</td>\n",
310
+ " </tr>\n",
311
+ " <tr>\n",
312
+ " <th>type_COLISÃO</th>\n",
313
+ " <td>0</td>\n",
314
+ " <td>0</td>\n",
315
+ " <td>0</td>\n",
316
+ " </tr>\n",
317
+ " <tr>\n",
318
+ " <th>type_OUTROS</th>\n",
319
+ " <td>0</td>\n",
320
+ " <td>0</td>\n",
321
+ " <td>0</td>\n",
322
+ " </tr>\n",
323
+ " </tbody>\n",
324
+ "</table>\n",
325
+ "</div>"
326
+ ],
327
+ "text/plain": [
328
+ " 0 1 2\n",
329
+ "latitude -30.009614 -30.0403 -30.069\n",
330
+ "longitude -51.185581 -51.1958 -51.1437\n",
331
+ "feridos True True True\n",
332
+ "feridos_gr False False False\n",
333
+ "fatais False False False\n",
334
+ "caminhao False False False\n",
335
+ "moto True True False\n",
336
+ "cars True True True\n",
337
+ "transport False False False\n",
338
+ "others False False False\n",
339
+ "holiday False True True\n",
340
+ "day_1 0 0 0\n",
341
+ "day_2 0 0 0\n",
342
+ "day_3 0 0 0\n",
343
+ "day_4 0 0 0\n",
344
+ "day_5 1 0 0\n",
345
+ "day_6 0 1 1\n",
346
+ "hour_1 0 0 0\n",
347
+ "hour_2 0 0 0\n",
348
+ "hour_3 0 0 0\n",
349
+ "hour_4 0 0 0\n",
350
+ "hour_5 0 0 0\n",
351
+ "hour_6 0 0 0\n",
352
+ "hour_7 0 0 0\n",
353
+ "hour_8 0 0 0\n",
354
+ "hour_9 0 0 0\n",
355
+ "hour_10 0 1 0\n",
356
+ "hour_11 0 0 0\n",
357
+ "hour_12 0 0 0\n",
358
+ "hour_13 0 0 0\n",
359
+ "hour_14 0 0 0\n",
360
+ "hour_15 0 0 0\n",
361
+ "hour_16 0 0 0\n",
362
+ "hour_17 0 0 0\n",
363
+ "hour_18 0 0 0\n",
364
+ "hour_19 1 0 1\n",
365
+ "hour_20 0 0 0\n",
366
+ "hour_21 0 0 0\n",
367
+ "hour_22 0 0 0\n",
368
+ "hour_23 0 0 0\n",
369
+ "type_ATROPELAMENTO 0 0 1\n",
370
+ "type_CHOQUE 0 0 0\n",
371
+ "type_COLISÃO 0 0 0\n",
372
+ "type_OUTROS 0 0 0"
373
+ ]
374
+ },
375
+ "execution_count": 1,
376
+ "metadata": {},
377
+ "output_type": "execute_result"
378
+ }
379
+ ],
380
+ "source": [
381
+ "import os.path as path\n",
382
+ "from pandas import read_csv\n",
383
+ "\n",
384
+ "file_csv = path.abspath(\"../\")\n",
385
+ "\n",
386
+ "file_csv = path.join(file_csv, \"data\" ,\"accidents_trans.csv\")\n",
387
+ "\n",
388
+ "accidents_trans = read_csv(file_csv)\n",
389
+ "\n",
390
+ "accidents_trans.head(3).T"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "# 3. Data Preparation"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 2,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "import joblib as jb # Use to save the model to deploy\n",
407
+ "from sklearn.preprocessing import StandardScaler\n",
408
+ "from sklearn.model_selection import train_test_split"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 3,
414
+ "metadata": {},
415
+ "outputs": [
416
+ {
417
+ "name": "stdout",
418
+ "output_type": "stream",
419
+ "text": [
420
+ "Our model to predict the probability of feridos will be create with 68218 rows and 41 features.\n"
421
+ ]
422
+ }
423
+ ],
424
+ "source": [
425
+ "outputs = [\"feridos\", \"feridos_gr\", \"fatais\"]\n",
426
+ "inputs = [col for col in accidents_trans.columns if col not in outputs]\n",
427
+ "\n",
428
+ "X = accidents_trans[inputs].copy()\n",
429
+ "Y = accidents_trans[outputs].copy()\n",
430
+ "\n",
431
+ "# Filtering data considering the output\n",
432
+ "output = \"feridos\"\n",
433
+ "\n",
434
+ "if output == \"feridos_gr\":\n",
435
+ " X = X[Y[\"feridos\"]]\n",
436
+ " Y = Y.loc[Y[\"feridos\"], \"feridos_gr\"]\n",
437
+ "elif output == \"fatais\":\n",
438
+ " X = X[Y[\"feridos_gr\"]]\n",
439
+ " Y = Y.loc[Y[\"feridos_gr\"], \"fatais\"]\n",
440
+ "else:\n",
441
+ " Y = Y[\"feridos\"]\n",
442
+ "\n",
443
+ "print(f\"Our model to predict the probability of \" \\\n",
444
+ " f\"{output} will be create with {X.shape[0]} \" \\\n",
445
+ " f\"rows and {X.shape[1]} features.\")"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": 4,
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "import csv\n",
455
+ "\n",
456
+ "with open(\"model_features.csv\", 'w') as f:\n",
457
+ " writer = csv.writer(f)\n",
458
+ " writer.writerow(X.columns)"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "markdown",
463
+ "metadata": {},
464
+ "source": [
465
+ "Considering that we will use models scaling sensitive, we will need to scale our data first. Beside this, we will need to save our scaler for future use."
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": 5,
471
+ "metadata": {},
472
+ "outputs": [
473
+ {
474
+ "data": {
475
+ "text/plain": [
476
+ "['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\scaler_feridos.pkl']"
477
+ ]
478
+ },
479
+ "execution_count": 5,
480
+ "metadata": {},
481
+ "output_type": "execute_result"
482
+ }
483
+ ],
484
+ "source": [
485
+ "# Setting the random state using my luck number :-)\n",
486
+ "lucky_num = 7\n",
487
+ "\n",
488
+ "# X_train and y_train to train our model\n",
489
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
490
+ " X,\n",
491
+ " Y,\n",
492
+ " test_size=0.30,\n",
493
+ " random_state=lucky_num,\n",
494
+ " shuffle=True, # Used because our data is sort by date\n",
495
+ " stratify=Y) # Used because our data is unbalanced\n",
496
+ "\n",
497
+ "# Scaling\n",
498
+ "scaler = StandardScaler()\n",
499
+ "X_train = scaler.fit_transform(X_train)\n",
500
+ "X_test = scaler.transform(X_test)\n",
501
+ "\n",
502
+ "# Saving scaler\n",
503
+ "file_name = \"scaler_\" + output + '.pkl'\n",
504
+ "jb.dump(scaler, path.join(path.abspath(\"./\"), file_name))"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "metadata": {},
510
+ "source": [
511
+ "# 4. Data Modeling\n",
512
+ "\n",
513
+ "We will create and use cross-validation to evaluate the following models:\n",
514
+ "\n",
515
+ "- Logistic Regression;\n",
516
+ "\n",
517
+ "- Gaussian Naive Bayes;\n",
518
+ "\n",
519
+ "- K Neighbors;\n",
520
+ "\n",
521
+ "- Random Forest;\n",
522
+ "\n",
523
+ "- Gradient Boosting; and,\n",
524
+ "\n",
525
+ "- XGBoost.\n",
526
+ "\n",
527
+ "We will use two scores to select and evaluate our models:\n",
528
+ "\n",
529
+ "- F1 score: composition between the precision (how much our model correct classify every true label) and recall (how moch our model correct indicate true labels); and,\n",
530
+ "\n",
531
+ "- Brier score: average between the correct and the predict probability.\n",
532
+ "\n",
533
+ "However, we will see other metrics to support our decision:\n",
534
+ "\n",
535
+ "- Accurancy;\n",
536
+ "\n",
537
+ "- ROC_AOC; and,\n",
538
+ "\n",
539
+ "- Log loss (an other way to quantify the quality of probability predictions).\n",
540
+ "\n",
541
+ "And, before you go, we will find for each model if there is a hyperparameter to deal with the unbalanced output."
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": 6,
547
+ "metadata": {},
548
+ "outputs": [],
549
+ "source": [
550
+ "import pandas as pd\n",
551
+ "import xgboost as xgb\n",
552
+ "from sklearn.naive_bayes import GaussianNB\n",
553
+ "from sklearn.neighbors import KNeighborsClassifier\n",
554
+ "from sklearn.linear_model import LogisticRegression\n",
555
+ "from sklearn.model_selection import cross_validate \n",
556
+ "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier\n",
557
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, brier_score_loss, log_loss\n",
558
+ "\n",
559
+ "scores = [\"accuracy\", \"f1\", \"precision\", \"recall\", \"roc_auc\", \"neg_brier_score\",\"neg_log_loss\"]"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": 7,
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": [
568
+ "def eval_model(cls) -> tuple:\n",
569
+ " \"\"\"This function will calculate the metrics\n",
570
+ " to evaluate a classification model.\n",
571
+ " \"\"\"\n",
572
+ " # Predicting labels and probabilities\n",
573
+ " y_pred = cls.predict(X_test)\n",
574
+ " y_prob = cls.predict_proba(X_test)[:,1]\n",
575
+ "\n",
576
+ " # Calculating scores\n",
577
+ " accurancy = accuracy_score(y_test, y_pred)\n",
578
+ " f1 = f1_score(y_test, y_pred)\n",
579
+ " recall = recall_score(y_test, y_pred)\n",
580
+ " precision = precision_score(y_test, y_pred)\n",
581
+ " roc_auc = roc_auc_score(y_test, y_prob) # https://datascience.stackexchange.com/questions/114394/does-roc-auc-different-between-crossval-and-test-set-indicate-overfitting-or-oth\n",
582
+ " brier_score = brier_score_loss(y_test, y_prob)\n",
583
+ " log_loss_value = log_loss(y_test, y_prob)\n",
584
+ "\n",
585
+ " return accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value\n",
586
+ "\n",
587
+ "def create_model(name: str, cls) -> list:\n",
588
+ " \"\"\"This function will create some models\n",
589
+ " and return scores to evaluate it.\"\"\"\n",
590
+ " # Ftting model\n",
591
+ " cls.fit(X_train, y_train)\n",
592
+ "\n",
593
+ " # Using cross-validation to evaluate the model fitted\n",
594
+ " cls_cross = cross_validate(\n",
595
+ " estimator=cls,\n",
596
+ " X=X_train,\n",
597
+ " y=y_train,\n",
598
+ " cv=5,\n",
599
+ " scoring=scores)\n",
600
+ "\n",
601
+ " df_cv = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
602
+ "\n",
603
+ " # Calculating score to test set\n",
604
+ " accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls)\n",
605
+ "\n",
606
+ " # Filling a dataframe to better presentation\n",
607
+ " df_cv.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
608
+ " df_cv.at[\"test_f1\", \"TestSet\"] = f1\n",
609
+ " df_cv.at[\"test_recall\", \"TestSet\"] = recall\n",
610
+ " df_cv.at[\"test_precision\", \"TestSet\"] = precision\n",
611
+ " df_cv.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
612
+ " df_cv.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
613
+ " df_cv.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
614
+ "\n",
615
+ " caption = f\"{name} Validation Scores\"\n",
616
+ "\n",
617
+ " display(df_cv.style.set_caption(caption))\n",
618
+ "\n",
619
+ " return [accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value]"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": 8,
625
+ "metadata": {},
626
+ "outputs": [
627
+ {
628
+ "data": {
629
+ "text/html": [
630
+ "<style type=\"text/css\">\n",
631
+ "</style>\n",
632
+ "<table id=\"T_31097\">\n",
633
+ " <caption>LR Validation Scores</caption>\n",
634
+ " <thead>\n",
635
+ " <tr>\n",
636
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
637
+ " <th id=\"T_31097_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
638
+ " <th id=\"T_31097_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
639
+ " <th id=\"T_31097_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
640
+ " <th id=\"T_31097_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
641
+ " <th id=\"T_31097_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
642
+ " <th id=\"T_31097_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
643
+ " </tr>\n",
644
+ " </thead>\n",
645
+ " <tbody>\n",
646
+ " <tr>\n",
647
+ " <th id=\"T_31097_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
648
+ " <td id=\"T_31097_row0_col0\" class=\"data row0 col0\" >0.082354</td>\n",
649
+ " <td id=\"T_31097_row0_col1\" class=\"data row0 col1\" >0.080257</td>\n",
650
+ " <td id=\"T_31097_row0_col2\" class=\"data row0 col2\" >0.089329</td>\n",
651
+ " <td id=\"T_31097_row0_col3\" class=\"data row0 col3\" >0.094720</td>\n",
652
+ " <td id=\"T_31097_row0_col4\" class=\"data row0 col4\" >0.087742</td>\n",
653
+ " <td id=\"T_31097_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
654
+ " </tr>\n",
655
+ " <tr>\n",
656
+ " <th id=\"T_31097_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
657
+ " <td id=\"T_31097_row1_col0\" class=\"data row1 col0\" >0.016066</td>\n",
658
+ " <td id=\"T_31097_row1_col1\" class=\"data row1 col1\" >0.017635</td>\n",
659
+ " <td id=\"T_31097_row1_col2\" class=\"data row1 col2\" >0.020100</td>\n",
660
+ " <td id=\"T_31097_row1_col3\" class=\"data row1 col3\" >0.018260</td>\n",
661
+ " <td id=\"T_31097_row1_col4\" class=\"data row1 col4\" >0.018356</td>\n",
662
+ " <td id=\"T_31097_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
663
+ " </tr>\n",
664
+ " <tr>\n",
665
+ " <th id=\"T_31097_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
666
+ " <td id=\"T_31097_row2_col0\" class=\"data row2 col0\" >0.869228</td>\n",
667
+ " <td id=\"T_31097_row2_col1\" class=\"data row2 col1\" >0.868391</td>\n",
668
+ " <td id=\"T_31097_row2_col2\" class=\"data row2 col2\" >0.872356</td>\n",
669
+ " <td id=\"T_31097_row2_col3\" class=\"data row2 col3\" >0.869005</td>\n",
670
+ " <td id=\"T_31097_row2_col4\" class=\"data row2 col4\" >0.867539</td>\n",
671
+ " <td id=\"T_31097_row2_col5\" class=\"data row2 col5\" >0.865924</td>\n",
672
+ " </tr>\n",
673
+ " <tr>\n",
674
+ " <th id=\"T_31097_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
675
+ " <td id=\"T_31097_row3_col0\" class=\"data row3 col0\" >0.817584</td>\n",
676
+ " <td id=\"T_31097_row3_col1\" class=\"data row3 col1\" >0.818116</td>\n",
677
+ " <td id=\"T_31097_row3_col2\" class=\"data row3 col2\" >0.823920</td>\n",
678
+ " <td id=\"T_31097_row3_col3\" class=\"data row3 col3\" >0.819611</td>\n",
679
+ " <td id=\"T_31097_row3_col4\" class=\"data row3 col4\" >0.817011</td>\n",
680
+ " <td id=\"T_31097_row3_col5\" class=\"data row3 col5\" >0.814469</td>\n",
681
+ " </tr>\n",
682
+ " <tr>\n",
683
+ " <th id=\"T_31097_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
684
+ " <td id=\"T_31097_row4_col0\" class=\"data row4 col0\" >0.854135</td>\n",
685
+ " <td id=\"T_31097_row4_col1\" class=\"data row4 col1\" >0.846154</td>\n",
686
+ " <td id=\"T_31097_row4_col2\" class=\"data row4 col2\" >0.850582</td>\n",
687
+ " <td id=\"T_31097_row4_col3\" class=\"data row4 col3\" >0.844326</td>\n",
688
+ " <td id=\"T_31097_row4_col4\" class=\"data row4 col4\" >0.844498</td>\n",
689
+ " <td id=\"T_31097_row4_col5\" class=\"data row4 col5\" >0.843439</td>\n",
690
+ " </tr>\n",
691
+ " <tr>\n",
692
+ " <th id=\"T_31097_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
693
+ " <td id=\"T_31097_row5_col0\" class=\"data row5 col0\" >0.784034</td>\n",
694
+ " <td id=\"T_31097_row5_col1\" class=\"data row5 col1\" >0.791877</td>\n",
695
+ " <td id=\"T_31097_row5_col2\" class=\"data row5 col2\" >0.798880</td>\n",
696
+ " <td id=\"T_31097_row5_col3\" class=\"data row5 col3\" >0.796301</td>\n",
697
+ " <td id=\"T_31097_row5_col4\" class=\"data row5 col4\" >0.791258</td>\n",
698
+ " <td id=\"T_31097_row5_col5\" class=\"data row5 col5\" >0.787423</td>\n",
699
+ " </tr>\n",
700
+ " <tr>\n",
701
+ " <th id=\"T_31097_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
702
+ " <td id=\"T_31097_row6_col0\" class=\"data row6 col0\" >0.903418</td>\n",
703
+ " <td id=\"T_31097_row6_col1\" class=\"data row6 col1\" >0.904970</td>\n",
704
+ " <td id=\"T_31097_row6_col2\" class=\"data row6 col2\" >0.906377</td>\n",
705
+ " <td id=\"T_31097_row6_col3\" class=\"data row6 col3\" >0.902405</td>\n",
706
+ " <td id=\"T_31097_row6_col4\" class=\"data row6 col4\" >0.906939</td>\n",
707
+ " <td id=\"T_31097_row6_col5\" class=\"data row6 col5\" >0.904458</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <th id=\"T_31097_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
711
+ " <td id=\"T_31097_row7_col0\" class=\"data row7 col0\" >-0.109808</td>\n",
712
+ " <td id=\"T_31097_row7_col1\" class=\"data row7 col1\" >-0.109221</td>\n",
713
+ " <td id=\"T_31097_row7_col2\" class=\"data row7 col2\" >-0.106382</td>\n",
714
+ " <td id=\"T_31097_row7_col3\" class=\"data row7 col3\" >-0.110939</td>\n",
715
+ " <td id=\"T_31097_row7_col4\" class=\"data row7 col4\" >-0.109709</td>\n",
716
+ " <td id=\"T_31097_row7_col5\" class=\"data row7 col5\" >-0.110435</td>\n",
717
+ " </tr>\n",
718
+ " <tr>\n",
719
+ " <th id=\"T_31097_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
720
+ " <td id=\"T_31097_row8_col0\" class=\"data row8 col0\" >-0.370200</td>\n",
721
+ " <td id=\"T_31097_row8_col1\" class=\"data row8 col1\" >-0.366684</td>\n",
722
+ " <td id=\"T_31097_row8_col2\" class=\"data row8 col2\" >-0.360534</td>\n",
723
+ " <td id=\"T_31097_row8_col3\" class=\"data row8 col3\" >-0.372374</td>\n",
724
+ " <td id=\"T_31097_row8_col4\" class=\"data row8 col4\" >-0.367532</td>\n",
725
+ " <td id=\"T_31097_row8_col5\" class=\"data row8 col5\" >-0.370350</td>\n",
726
+ " </tr>\n",
727
+ " </tbody>\n",
728
+ "</table>\n"
729
+ ],
730
+ "text/plain": [
731
+ "<pandas.io.formats.style.Styler at 0x1a1a6427220>"
732
+ ]
733
+ },
734
+ "metadata": {},
735
+ "output_type": "display_data"
736
+ },
737
+ {
738
+ "data": {
739
+ "text/html": [
740
+ "<style type=\"text/css\">\n",
741
+ "</style>\n",
742
+ "<table id=\"T_750d2\">\n",
743
+ " <caption>NB Validation Scores</caption>\n",
744
+ " <thead>\n",
745
+ " <tr>\n",
746
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
747
+ " <th id=\"T_750d2_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
748
+ " <th id=\"T_750d2_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
749
+ " <th id=\"T_750d2_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
750
+ " <th id=\"T_750d2_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
751
+ " <th id=\"T_750d2_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
752
+ " <th id=\"T_750d2_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
753
+ " </tr>\n",
754
+ " </thead>\n",
755
+ " <tbody>\n",
756
+ " <tr>\n",
757
+ " <th id=\"T_750d2_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
758
+ " <td id=\"T_750d2_row0_col0\" class=\"data row0 col0\" >0.035410</td>\n",
759
+ " <td id=\"T_750d2_row0_col1\" class=\"data row0 col1\" >0.030015</td>\n",
760
+ " <td id=\"T_750d2_row0_col2\" class=\"data row0 col2\" >0.032639</td>\n",
761
+ " <td id=\"T_750d2_row0_col3\" class=\"data row0 col3\" >0.029752</td>\n",
762
+ " <td id=\"T_750d2_row0_col4\" class=\"data row0 col4\" >0.030653</td>\n",
763
+ " <td id=\"T_750d2_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
764
+ " </tr>\n",
765
+ " <tr>\n",
766
+ " <th id=\"T_750d2_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
767
+ " <td id=\"T_750d2_row1_col0\" class=\"data row1 col0\" >0.037826</td>\n",
768
+ " <td id=\"T_750d2_row1_col1\" class=\"data row1 col1\" >0.040993</td>\n",
769
+ " <td id=\"T_750d2_row1_col2\" class=\"data row1 col2\" >0.032376</td>\n",
770
+ " <td id=\"T_750d2_row1_col3\" class=\"data row1 col3\" >0.030767</td>\n",
771
+ " <td id=\"T_750d2_row1_col4\" class=\"data row1 col4\" >0.028092</td>\n",
772
+ " <td id=\"T_750d2_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
773
+ " </tr>\n",
774
+ " <tr>\n",
775
+ " <th id=\"T_750d2_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
776
+ " <td id=\"T_750d2_row2_col0\" class=\"data row2 col0\" >0.768401</td>\n",
777
+ " <td id=\"T_750d2_row2_col1\" class=\"data row2 col1\" >0.763376</td>\n",
778
+ " <td id=\"T_750d2_row2_col2\" class=\"data row2 col2\" >0.765131</td>\n",
779
+ " <td id=\"T_750d2_row2_col3\" class=\"data row2 col3\" >0.771518</td>\n",
780
+ " <td id=\"T_750d2_row2_col4\" class=\"data row2 col4\" >0.772251</td>\n",
781
+ " <td id=\"T_750d2_row2_col5\" class=\"data row2 col5\" >0.766637</td>\n",
782
+ " </tr>\n",
783
+ " <tr>\n",
784
+ " <th id=\"T_750d2_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
785
+ " <td id=\"T_750d2_row3_col0\" class=\"data row3 col0\" >0.667068</td>\n",
786
+ " <td id=\"T_750d2_row3_col1\" class=\"data row3 col1\" >0.654223</td>\n",
787
+ " <td id=\"T_750d2_row3_col2\" class=\"data row3 col2\" >0.660922</td>\n",
788
+ " <td id=\"T_750d2_row3_col3\" class=\"data row3 col3\" >0.675876</td>\n",
789
+ " <td id=\"T_750d2_row3_col4\" class=\"data row3 col4\" >0.668899</td>\n",
790
+ " <td id=\"T_750d2_row3_col5\" class=\"data row3 col5\" >0.664795</td>\n",
791
+ " </tr>\n",
792
+ " <tr>\n",
793
+ " <th id=\"T_750d2_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
794
+ " <td id=\"T_750d2_row4_col0\" class=\"data row4 col0\" >0.720885</td>\n",
795
+ " <td id=\"T_750d2_row4_col1\" class=\"data row4 col1\" >0.720836</td>\n",
796
+ " <td id=\"T_750d2_row4_col2\" class=\"data row4 col2\" >0.717898</td>\n",
797
+ " <td id=\"T_750d2_row4_col3\" class=\"data row4 col3\" >0.719254</td>\n",
798
+ " <td id=\"T_750d2_row4_col4\" class=\"data row4 col4\" >0.732333</td>\n",
799
+ " <td id=\"T_750d2_row4_col5\" class=\"data row4 col5\" >0.717684</td>\n",
800
+ " </tr>\n",
801
+ " <tr>\n",
802
+ " <th id=\"T_750d2_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
803
+ " <td id=\"T_750d2_row5_col0\" class=\"data row5 col0\" >0.620728</td>\n",
804
+ " <td id=\"T_750d2_row5_col1\" class=\"data row5 col1\" >0.598880</td>\n",
805
+ " <td id=\"T_750d2_row5_col2\" class=\"data row5 col2\" >0.612325</td>\n",
806
+ " <td id=\"T_750d2_row5_col3\" class=\"data row5 col3\" >0.637433</td>\n",
807
+ " <td id=\"T_750d2_row5_col4\" class=\"data row5 col4\" >0.615579</td>\n",
808
+ " <td id=\"T_750d2_row5_col5\" class=\"data row5 col5\" >0.619166</td>\n",
809
+ " </tr>\n",
810
+ " <tr>\n",
811
+ " <th id=\"T_750d2_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
812
+ " <td id=\"T_750d2_row6_col0\" class=\"data row6 col0\" >0.852290</td>\n",
813
+ " <td id=\"T_750d2_row6_col1\" class=\"data row6 col1\" >0.847184</td>\n",
814
+ " <td id=\"T_750d2_row6_col2\" class=\"data row6 col2\" >0.843733</td>\n",
815
+ " <td id=\"T_750d2_row6_col3\" class=\"data row6 col3\" >0.851873</td>\n",
816
+ " <td id=\"T_750d2_row6_col4\" class=\"data row6 col4\" >0.856047</td>\n",
817
+ " <td id=\"T_750d2_row6_col5\" class=\"data row6 col5\" >0.848834</td>\n",
818
+ " </tr>\n",
819
+ " <tr>\n",
820
+ " <th id=\"T_750d2_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
821
+ " <td id=\"T_750d2_row7_col0\" class=\"data row7 col0\" >-0.206596</td>\n",
822
+ " <td id=\"T_750d2_row7_col1\" class=\"data row7 col1\" >-0.210362</td>\n",
823
+ " <td id=\"T_750d2_row7_col2\" class=\"data row7 col2\" >-0.211968</td>\n",
824
+ " <td id=\"T_750d2_row7_col3\" class=\"data row7 col3\" >-0.204214</td>\n",
825
+ " <td id=\"T_750d2_row7_col4\" class=\"data row7 col4\" >-0.202682</td>\n",
826
+ " <td id=\"T_750d2_row7_col5\" class=\"data row7 col5\" >-0.208278</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <th id=\"T_750d2_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
830
+ " <td id=\"T_750d2_row8_col0\" class=\"data row8 col0\" >-1.668014</td>\n",
831
+ " <td id=\"T_750d2_row8_col1\" class=\"data row8 col1\" >-1.788896</td>\n",
832
+ " <td id=\"T_750d2_row8_col2\" class=\"data row8 col2\" >-1.917438</td>\n",
833
+ " <td id=\"T_750d2_row8_col3\" class=\"data row8 col3\" >-1.662381</td>\n",
834
+ " <td id=\"T_750d2_row8_col4\" class=\"data row8 col4\" >-1.670358</td>\n",
835
+ " <td id=\"T_750d2_row8_col5\" class=\"data row8 col5\" >-1.761326</td>\n",
836
+ " </tr>\n",
837
+ " </tbody>\n",
838
+ "</table>\n"
839
+ ],
840
+ "text/plain": [
841
+ "<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
842
+ ]
843
+ },
844
+ "metadata": {},
845
+ "output_type": "display_data"
846
+ },
847
+ {
848
+ "data": {
849
+ "text/html": [
850
+ "<style type=\"text/css\">\n",
851
+ "</style>\n",
852
+ "<table id=\"T_2d5a3\">\n",
853
+ " <caption>KNN Validation Scores</caption>\n",
854
+ " <thead>\n",
855
+ " <tr>\n",
856
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
857
+ " <th id=\"T_2d5a3_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
858
+ " <th id=\"T_2d5a3_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
859
+ " <th id=\"T_2d5a3_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
860
+ " <th id=\"T_2d5a3_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
861
+ " <th id=\"T_2d5a3_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
862
+ " <th id=\"T_2d5a3_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
863
+ " </tr>\n",
864
+ " </thead>\n",
865
+ " <tbody>\n",
866
+ " <tr>\n",
867
+ " <th id=\"T_2d5a3_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
868
+ " <td id=\"T_2d5a3_row0_col0\" class=\"data row0 col0\" >0.010002</td>\n",
869
+ " <td id=\"T_2d5a3_row0_col1\" class=\"data row0 col1\" >0.011312</td>\n",
870
+ " <td id=\"T_2d5a3_row0_col2\" class=\"data row0 col2\" >0.011621</td>\n",
871
+ " <td id=\"T_2d5a3_row0_col3\" class=\"data row0 col3\" >0.013843</td>\n",
872
+ " <td id=\"T_2d5a3_row0_col4\" class=\"data row0 col4\" >0.011473</td>\n",
873
+ " <td id=\"T_2d5a3_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
874
+ " </tr>\n",
875
+ " <tr>\n",
876
+ " <th id=\"T_2d5a3_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
877
+ " <td id=\"T_2d5a3_row1_col0\" class=\"data row1 col0\" >1.660269</td>\n",
878
+ " <td id=\"T_2d5a3_row1_col1\" class=\"data row1 col1\" >1.360570</td>\n",
879
+ " <td id=\"T_2d5a3_row1_col2\" class=\"data row1 col2\" >1.651296</td>\n",
880
+ " <td id=\"T_2d5a3_row1_col3\" class=\"data row1 col3\" >1.734129</td>\n",
881
+ " <td id=\"T_2d5a3_row1_col4\" class=\"data row1 col4\" >1.823339</td>\n",
882
+ " <td id=\"T_2d5a3_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
883
+ " </tr>\n",
884
+ " <tr>\n",
885
+ " <th id=\"T_2d5a3_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
886
+ " <td id=\"T_2d5a3_row2_col0\" class=\"data row2 col0\" >0.842320</td>\n",
887
+ " <td id=\"T_2d5a3_row2_col1\" class=\"data row2 col1\" >0.848707</td>\n",
888
+ " <td id=\"T_2d5a3_row2_col2\" class=\"data row2 col2\" >0.847330</td>\n",
889
+ " <td id=\"T_2d5a3_row2_col3\" class=\"data row2 col3\" >0.842723</td>\n",
890
+ " <td id=\"T_2d5a3_row2_col4\" class=\"data row2 col4\" >0.847749</td>\n",
891
+ " <td id=\"T_2d5a3_row2_col5\" class=\"data row2 col5\" >0.843692</td>\n",
892
+ " </tr>\n",
893
+ " <tr>\n",
894
+ " <th id=\"T_2d5a3_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
895
+ " <td id=\"T_2d5a3_row3_col0\" class=\"data row3 col0\" >0.776492</td>\n",
896
+ " <td id=\"T_2d5a3_row3_col1\" class=\"data row3 col1\" >0.787218</td>\n",
897
+ " <td id=\"T_2d5a3_row3_col2\" class=\"data row3 col2\" >0.783551</td>\n",
898
+ " <td id=\"T_2d5a3_row3_col3\" class=\"data row3 col3\" >0.779053</td>\n",
899
+ " <td id=\"T_2d5a3_row3_col4\" class=\"data row3 col4\" >0.786365</td>\n",
900
+ " <td id=\"T_2d5a3_row3_col5\" class=\"data row3 col5\" >0.779698</td>\n",
901
+ " </tr>\n",
902
+ " <tr>\n",
903
+ " <th id=\"T_2d5a3_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
904
+ " <td id=\"T_2d5a3_row4_col0\" class=\"data row4 col0\" >0.825758</td>\n",
905
+ " <td id=\"T_2d5a3_row4_col1\" class=\"data row4 col1\" >0.829867</td>\n",
906
+ " <td id=\"T_2d5a3_row4_col2\" class=\"data row4 col2\" >0.833544</td>\n",
907
+ " <td id=\"T_2d5a3_row4_col3\" class=\"data row4 col3\" >0.820068</td>\n",
908
+ " <td id=\"T_2d5a3_row4_col4\" class=\"data row4 col4\" >0.826691</td>\n",
909
+ " <td id=\"T_2d5a3_row4_col5\" class=\"data row4 col5\" >0.823778</td>\n",
910
+ " </tr>\n",
911
+ " <tr>\n",
912
+ " <th id=\"T_2d5a3_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
913
+ " <td id=\"T_2d5a3_row5_col0\" class=\"data row5 col0\" >0.732773</td>\n",
914
+ " <td id=\"T_2d5a3_row5_col1\" class=\"data row5 col1\" >0.748739</td>\n",
915
+ " <td id=\"T_2d5a3_row5_col2\" class=\"data row5 col2\" >0.739216</td>\n",
916
+ " <td id=\"T_2d5a3_row5_col3\" class=\"data row5 col3\" >0.741945</td>\n",
917
+ " <td id=\"T_2d5a3_row5_col4\" class=\"data row5 col4\" >0.749790</td>\n",
918
+ " <td id=\"T_2d5a3_row5_col5\" class=\"data row5 col5\" >0.740097</td>\n",
919
+ " </tr>\n",
920
+ " <tr>\n",
921
+ " <th id=\"T_2d5a3_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
922
+ " <td id=\"T_2d5a3_row6_col0\" class=\"data row6 col0\" >0.867330</td>\n",
923
+ " <td id=\"T_2d5a3_row6_col1\" class=\"data row6 col1\" >0.869924</td>\n",
924
+ " <td id=\"T_2d5a3_row6_col2\" class=\"data row6 col2\" >0.872951</td>\n",
925
+ " <td id=\"T_2d5a3_row6_col3\" class=\"data row6 col3\" >0.866868</td>\n",
926
+ " <td id=\"T_2d5a3_row6_col4\" class=\"data row6 col4\" >0.872277</td>\n",
927
+ " <td id=\"T_2d5a3_row6_col5\" class=\"data row6 col5\" >0.872155</td>\n",
928
+ " </tr>\n",
929
+ " <tr>\n",
930
+ " <th id=\"T_2d5a3_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
931
+ " <td id=\"T_2d5a3_row7_col0\" class=\"data row7 col0\" >-0.130989</td>\n",
932
+ " <td id=\"T_2d5a3_row7_col1\" class=\"data row7 col1\" >-0.127425</td>\n",
933
+ " <td id=\"T_2d5a3_row7_col2\" class=\"data row7 col2\" >-0.126777</td>\n",
934
+ " <td id=\"T_2d5a3_row7_col3\" class=\"data row7 col3\" >-0.130655</td>\n",
935
+ " <td id=\"T_2d5a3_row7_col4\" class=\"data row7 col4\" >-0.127401</td>\n",
936
+ " <td id=\"T_2d5a3_row7_col5\" class=\"data row7 col5\" >-0.128215</td>\n",
937
+ " </tr>\n",
938
+ " <tr>\n",
939
+ " <th id=\"T_2d5a3_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
940
+ " <td id=\"T_2d5a3_row8_col0\" class=\"data row8 col0\" >-2.083997</td>\n",
941
+ " <td id=\"T_2d5a3_row8_col1\" class=\"data row8 col1\" >-1.959589</td>\n",
942
+ " <td id=\"T_2d5a3_row8_col2\" class=\"data row8 col2\" >-1.815403</td>\n",
943
+ " <td id=\"T_2d5a3_row8_col3\" class=\"data row8 col3\" >-2.007178</td>\n",
944
+ " <td id=\"T_2d5a3_row8_col4\" class=\"data row8 col4\" >-1.929602</td>\n",
945
+ " <td id=\"T_2d5a3_row8_col5\" class=\"data row8 col5\" >-1.877810</td>\n",
946
+ " </tr>\n",
947
+ " </tbody>\n",
948
+ "</table>\n"
949
+ ],
950
+ "text/plain": [
951
+ "<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
952
+ ]
953
+ },
954
+ "metadata": {},
955
+ "output_type": "display_data"
956
+ },
957
+ {
958
+ "data": {
959
+ "text/html": [
960
+ "<style type=\"text/css\">\n",
961
+ "</style>\n",
962
+ "<table id=\"T_08606\">\n",
963
+ " <caption>RF Validation Scores</caption>\n",
964
+ " <thead>\n",
965
+ " <tr>\n",
966
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
967
+ " <th id=\"T_08606_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
968
+ " <th id=\"T_08606_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
969
+ " <th id=\"T_08606_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
970
+ " <th id=\"T_08606_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
971
+ " <th id=\"T_08606_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
972
+ " <th id=\"T_08606_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
973
+ " </tr>\n",
974
+ " </thead>\n",
975
+ " <tbody>\n",
976
+ " <tr>\n",
977
+ " <th id=\"T_08606_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
978
+ " <td id=\"T_08606_row0_col0\" class=\"data row0 col0\" >4.099665</td>\n",
979
+ " <td id=\"T_08606_row0_col1\" class=\"data row0 col1\" >4.061200</td>\n",
980
+ " <td id=\"T_08606_row0_col2\" class=\"data row0 col2\" >4.090116</td>\n",
981
+ " <td id=\"T_08606_row0_col3\" class=\"data row0 col3\" >4.055705</td>\n",
982
+ " <td id=\"T_08606_row0_col4\" class=\"data row0 col4\" >4.050387</td>\n",
983
+ " <td id=\"T_08606_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
984
+ " </tr>\n",
985
+ " <tr>\n",
986
+ " <th id=\"T_08606_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
987
+ " <td id=\"T_08606_row1_col0\" class=\"data row1 col0\" >0.390365</td>\n",
988
+ " <td id=\"T_08606_row1_col1\" class=\"data row1 col1\" >0.389244</td>\n",
989
+ " <td id=\"T_08606_row1_col2\" class=\"data row1 col2\" >0.392108</td>\n",
990
+ " <td id=\"T_08606_row1_col3\" class=\"data row1 col3\" >0.387358</td>\n",
991
+ " <td id=\"T_08606_row1_col4\" class=\"data row1 col4\" >0.400155</td>\n",
992
+ " <td id=\"T_08606_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
993
+ " </tr>\n",
994
+ " <tr>\n",
995
+ " <th id=\"T_08606_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
996
+ " <td id=\"T_08606_row2_col0\" class=\"data row2 col0\" >0.856141</td>\n",
997
+ " <td id=\"T_08606_row2_col1\" class=\"data row2 col1\" >0.859282</td>\n",
998
+ " <td id=\"T_08606_row2_col2\" class=\"data row2 col2\" >0.861571</td>\n",
999
+ " <td id=\"T_08606_row2_col3\" class=\"data row2 col3\" >0.853508</td>\n",
1000
+ " <td id=\"T_08606_row2_col4\" class=\"data row2 col4\" >0.855393</td>\n",
1001
+ " <td id=\"T_08606_row2_col5\" class=\"data row2 col5\" >0.856152</td>\n",
1002
+ " </tr>\n",
1003
+ " <tr>\n",
1004
+ " <th id=\"T_08606_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1005
+ " <td id=\"T_08606_row3_col0\" class=\"data row3 col0\" >0.800349</td>\n",
1006
+ " <td id=\"T_08606_row3_col1\" class=\"data row3 col1\" >0.805217</td>\n",
1007
+ " <td id=\"T_08606_row3_col2\" class=\"data row3 col2\" >0.807681</td>\n",
1008
+ " <td id=\"T_08606_row3_col3\" class=\"data row3 col3\" >0.798676</td>\n",
1009
+ " <td id=\"T_08606_row3_col4\" class=\"data row3 col4\" >0.799477</td>\n",
1010
+ " <td id=\"T_08606_row3_col5\" class=\"data row3 col5\" >0.800623</td>\n",
1011
+ " </tr>\n",
1012
+ " <tr>\n",
1013
+ " <th id=\"T_08606_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1014
+ " <td id=\"T_08606_row4_col0\" class=\"data row4 col0\" >0.831522</td>\n",
1015
+ " <td id=\"T_08606_row4_col1\" class=\"data row4 col1\" >0.834234</td>\n",
1016
+ " <td id=\"T_08606_row4_col2\" class=\"data row4 col2\" >0.840194</td>\n",
1017
+ " <td id=\"T_08606_row4_col3\" class=\"data row4 col3\" >0.821006</td>\n",
1018
+ " <td id=\"T_08606_row4_col4\" class=\"data row4 col4\" >0.829717</td>\n",
1019
+ " <td id=\"T_08606_row4_col5\" class=\"data row4 col5\" >0.830547</td>\n",
1020
+ " </tr>\n",
1021
+ " <tr>\n",
1022
+ " <th id=\"T_08606_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1023
+ " <td id=\"T_08606_row5_col0\" class=\"data row5 col0\" >0.771429</td>\n",
1024
+ " <td id=\"T_08606_row5_col1\" class=\"data row5 col1\" >0.778151</td>\n",
1025
+ " <td id=\"T_08606_row5_col2\" class=\"data row5 col2\" >0.777591</td>\n",
1026
+ " <td id=\"T_08606_row5_col3\" class=\"data row5 col3\" >0.777529</td>\n",
1027
+ " <td id=\"T_08606_row5_col4\" class=\"data row5 col4\" >0.771365</td>\n",
1028
+ " <td id=\"T_08606_row5_col5\" class=\"data row5 col5\" >0.772781</td>\n",
1029
+ " </tr>\n",
1030
+ " <tr>\n",
1031
+ " <th id=\"T_08606_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1032
+ " <td id=\"T_08606_row6_col0\" class=\"data row6 col0\" >0.890122</td>\n",
1033
+ " <td id=\"T_08606_row6_col1\" class=\"data row6 col1\" >0.890561</td>\n",
1034
+ " <td id=\"T_08606_row6_col2\" class=\"data row6 col2\" >0.897321</td>\n",
1035
+ " <td id=\"T_08606_row6_col3\" class=\"data row6 col3\" >0.887396</td>\n",
1036
+ " <td id=\"T_08606_row6_col4\" class=\"data row6 col4\" >0.891078</td>\n",
1037
+ " <td id=\"T_08606_row6_col5\" class=\"data row6 col5\" >0.893466</td>\n",
1038
+ " </tr>\n",
1039
+ " <tr>\n",
1040
+ " <th id=\"T_08606_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1041
+ " <td id=\"T_08606_row7_col0\" class=\"data row7 col0\" >-0.116884</td>\n",
1042
+ " <td id=\"T_08606_row7_col1\" class=\"data row7 col1\" >-0.114867</td>\n",
1043
+ " <td id=\"T_08606_row7_col2\" class=\"data row7 col2\" >-0.111343</td>\n",
1044
+ " <td id=\"T_08606_row7_col3\" class=\"data row7 col3\" >-0.117719</td>\n",
1045
+ " <td id=\"T_08606_row7_col4\" class=\"data row7 col4\" >-0.116295</td>\n",
1046
+ " <td id=\"T_08606_row7_col5\" class=\"data row7 col5\" >-0.115285</td>\n",
1047
+ " </tr>\n",
1048
+ " <tr>\n",
1049
+ " <th id=\"T_08606_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1050
+ " <td id=\"T_08606_row8_col0\" class=\"data row8 col0\" >-0.607395</td>\n",
1051
+ " <td id=\"T_08606_row8_col1\" class=\"data row8 col1\" >-0.579640</td>\n",
1052
+ " <td id=\"T_08606_row8_col2\" class=\"data row8 col2\" >-0.536542</td>\n",
1053
+ " <td id=\"T_08606_row8_col3\" class=\"data row8 col3\" >-0.614554</td>\n",
1054
+ " <td id=\"T_08606_row8_col4\" class=\"data row8 col4\" >-0.631888</td>\n",
1055
+ " <td id=\"T_08606_row8_col5\" class=\"data row8 col5\" >-0.562042</td>\n",
1056
+ " </tr>\n",
1057
+ " </tbody>\n",
1058
+ "</table>\n"
1059
+ ],
1060
+ "text/plain": [
1061
+ "<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
1062
+ ]
1063
+ },
1064
+ "metadata": {},
1065
+ "output_type": "display_data"
1066
+ },
1067
+ {
1068
+ "data": {
1069
+ "text/html": [
1070
+ "<style type=\"text/css\">\n",
1071
+ "</style>\n",
1072
+ "<table id=\"T_14134\">\n",
1073
+ " <caption>GBC Validation Scores</caption>\n",
1074
+ " <thead>\n",
1075
+ " <tr>\n",
1076
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1077
+ " <th id=\"T_14134_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1078
+ " <th id=\"T_14134_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1079
+ " <th id=\"T_14134_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1080
+ " <th id=\"T_14134_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1081
+ " <th id=\"T_14134_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1082
+ " <th id=\"T_14134_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1083
+ " </tr>\n",
1084
+ " </thead>\n",
1085
+ " <tbody>\n",
1086
+ " <tr>\n",
1087
+ " <th id=\"T_14134_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1088
+ " <td id=\"T_14134_row0_col0\" class=\"data row0 col0\" >4.591437</td>\n",
1089
+ " <td id=\"T_14134_row0_col1\" class=\"data row0 col1\" >4.437213</td>\n",
1090
+ " <td id=\"T_14134_row0_col2\" class=\"data row0 col2\" >4.121067</td>\n",
1091
+ " <td id=\"T_14134_row0_col3\" class=\"data row0 col3\" >4.142180</td>\n",
1092
+ " <td id=\"T_14134_row0_col4\" class=\"data row0 col4\" >4.113901</td>\n",
1093
+ " <td id=\"T_14134_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1094
+ " </tr>\n",
1095
+ " <tr>\n",
1096
+ " <th id=\"T_14134_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1097
+ " <td id=\"T_14134_row1_col0\" class=\"data row1 col0\" >0.055993</td>\n",
1098
+ " <td id=\"T_14134_row1_col1\" class=\"data row1 col1\" >0.048113</td>\n",
1099
+ " <td id=\"T_14134_row1_col2\" class=\"data row1 col2\" >0.049492</td>\n",
1100
+ " <td id=\"T_14134_row1_col3\" class=\"data row1 col3\" >0.050163</td>\n",
1101
+ " <td id=\"T_14134_row1_col4\" class=\"data row1 col4\" >0.055706</td>\n",
1102
+ " <td id=\"T_14134_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1103
+ " </tr>\n",
1104
+ " <tr>\n",
1105
+ " <th id=\"T_14134_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1106
+ " <td id=\"T_14134_row2_col0\" class=\"data row2 col0\" >0.871113</td>\n",
1107
+ " <td id=\"T_14134_row2_col1\" class=\"data row2 col1\" >0.873207</td>\n",
1108
+ " <td id=\"T_14134_row2_col2\" class=\"data row2 col2\" >0.878639</td>\n",
1109
+ " <td id=\"T_14134_row2_col3\" class=\"data row2 col3\" >0.870052</td>\n",
1110
+ " <td id=\"T_14134_row2_col4\" class=\"data row2 col4\" >0.870157</td>\n",
1111
+ " <td id=\"T_14134_row2_col5\" class=\"data row2 col5\" >0.871054</td>\n",
1112
+ " </tr>\n",
1113
+ " <tr>\n",
1114
+ " <th id=\"T_14134_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1115
+ " <td id=\"T_14134_row3_col0\" class=\"data row3 col0\" >0.817169</td>\n",
1116
+ " <td id=\"T_14134_row3_col1\" class=\"data row3 col1\" >0.820831</td>\n",
1117
+ " <td id=\"T_14134_row3_col2\" class=\"data row3 col2\" >0.827709</td>\n",
1118
+ " <td id=\"T_14134_row3_col3\" class=\"data row3 col3\" >0.817043</td>\n",
1119
+ " <td id=\"T_14134_row3_col4\" class=\"data row3 col4\" >0.817109</td>\n",
1120
+ " <td id=\"T_14134_row3_col5\" class=\"data row3 col5\" >0.817560</td>\n",
1121
+ " </tr>\n",
1122
+ " <tr>\n",
1123
+ " <th id=\"T_14134_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1124
+ " <td id=\"T_14134_row4_col0\" class=\"data row4 col0\" >0.869744</td>\n",
1125
+ " <td id=\"T_14134_row4_col1\" class=\"data row4 col1\" >0.869865</td>\n",
1126
+ " <td id=\"T_14134_row4_col2\" class=\"data row4 col2\" >0.881850</td>\n",
1127
+ " <td id=\"T_14134_row4_col3\" class=\"data row4 col3\" >0.862166</td>\n",
1128
+ " <td id=\"T_14134_row4_col4\" class=\"data row4 col4\" >0.862660</td>\n",
1129
+ " <td id=\"T_14134_row4_col5\" class=\"data row4 col5\" >0.867518</td>\n",
1130
+ " </tr>\n",
1131
+ " <tr>\n",
1132
+ " <th id=\"T_14134_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1133
+ " <td id=\"T_14134_row5_col0\" class=\"data row5 col0\" >0.770588</td>\n",
1134
+ " <td id=\"T_14134_row5_col1\" class=\"data row5 col1\" >0.777031</td>\n",
1135
+ " <td id=\"T_14134_row5_col2\" class=\"data row5 col2\" >0.779832</td>\n",
1136
+ " <td id=\"T_14134_row5_col3\" class=\"data row5 col3\" >0.776408</td>\n",
1137
+ " <td id=\"T_14134_row5_col4\" class=\"data row5 col4\" >0.776128</td>\n",
1138
+ " <td id=\"T_14134_row5_col5\" class=\"data row5 col5\" >0.773042</td>\n",
1139
+ " </tr>\n",
1140
+ " <tr>\n",
1141
+ " <th id=\"T_14134_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1142
+ " <td id=\"T_14134_row6_col0\" class=\"data row6 col0\" >0.907041</td>\n",
1143
+ " <td id=\"T_14134_row6_col1\" class=\"data row6 col1\" >0.908041</td>\n",
1144
+ " <td id=\"T_14134_row6_col2\" class=\"data row6 col2\" >0.911930</td>\n",
1145
+ " <td id=\"T_14134_row6_col3\" class=\"data row6 col3\" >0.906283</td>\n",
1146
+ " <td id=\"T_14134_row6_col4\" class=\"data row6 col4\" >0.909348</td>\n",
1147
+ " <td id=\"T_14134_row6_col5\" class=\"data row6 col5\" >0.908648</td>\n",
1148
+ " </tr>\n",
1149
+ " <tr>\n",
1150
+ " <th id=\"T_14134_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1151
+ " <td id=\"T_14134_row7_col0\" class=\"data row7 col0\" >-0.105054</td>\n",
1152
+ " <td id=\"T_14134_row7_col1\" class=\"data row7 col1\" >-0.103463</td>\n",
1153
+ " <td id=\"T_14134_row7_col2\" class=\"data row7 col2\" >-0.099338</td>\n",
1154
+ " <td id=\"T_14134_row7_col3\" class=\"data row7 col3\" >-0.104658</td>\n",
1155
+ " <td id=\"T_14134_row7_col4\" class=\"data row7 col4\" >-0.104459</td>\n",
1156
+ " <td id=\"T_14134_row7_col5\" class=\"data row7 col5\" >-0.104280</td>\n",
1157
+ " </tr>\n",
1158
+ " <tr>\n",
1159
+ " <th id=\"T_14134_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1160
+ " <td id=\"T_14134_row8_col0\" class=\"data row8 col0\" >-0.352792</td>\n",
1161
+ " <td id=\"T_14134_row8_col1\" class=\"data row8 col1\" >-0.348499</td>\n",
1162
+ " <td id=\"T_14134_row8_col2\" class=\"data row8 col2\" >-0.338605</td>\n",
1163
+ " <td id=\"T_14134_row8_col3\" class=\"data row8 col3\" >-0.351285</td>\n",
1164
+ " <td id=\"T_14134_row8_col4\" class=\"data row8 col4\" >-0.350193</td>\n",
1165
+ " <td id=\"T_14134_row8_col5\" class=\"data row8 col5\" >-0.350152</td>\n",
1166
+ " </tr>\n",
1167
+ " </tbody>\n",
1168
+ "</table>\n"
1169
+ ],
1170
+ "text/plain": [
1171
+ "<pandas.io.formats.style.Styler at 0x1a1a86dd240>"
1172
+ ]
1173
+ },
1174
+ "metadata": {},
1175
+ "output_type": "display_data"
1176
+ },
1177
+ {
1178
+ "data": {
1179
+ "text/html": [
1180
+ "<style type=\"text/css\">\n",
1181
+ "</style>\n",
1182
+ "<table id=\"T_25121\">\n",
1183
+ " <caption>XGB Validation Scores</caption>\n",
1184
+ " <thead>\n",
1185
+ " <tr>\n",
1186
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1187
+ " <th id=\"T_25121_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1188
+ " <th id=\"T_25121_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1189
+ " <th id=\"T_25121_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1190
+ " <th id=\"T_25121_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1191
+ " <th id=\"T_25121_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1192
+ " <th id=\"T_25121_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1193
+ " </tr>\n",
1194
+ " </thead>\n",
1195
+ " <tbody>\n",
1196
+ " <tr>\n",
1197
+ " <th id=\"T_25121_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1198
+ " <td id=\"T_25121_row0_col0\" class=\"data row0 col0\" >3.802029</td>\n",
1199
+ " <td id=\"T_25121_row0_col1\" class=\"data row0 col1\" >3.036764</td>\n",
1200
+ " <td id=\"T_25121_row0_col2\" class=\"data row0 col2\" >2.979647</td>\n",
1201
+ " <td id=\"T_25121_row0_col3\" class=\"data row0 col3\" >2.177232</td>\n",
1202
+ " <td id=\"T_25121_row0_col4\" class=\"data row0 col4\" >2.287098</td>\n",
1203
+ " <td id=\"T_25121_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1204
+ " </tr>\n",
1205
+ " <tr>\n",
1206
+ " <th id=\"T_25121_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1207
+ " <td id=\"T_25121_row1_col0\" class=\"data row1 col0\" >0.069013</td>\n",
1208
+ " <td id=\"T_25121_row1_col1\" class=\"data row1 col1\" >0.071819</td>\n",
1209
+ " <td id=\"T_25121_row1_col2\" class=\"data row1 col2\" >0.049402</td>\n",
1210
+ " <td id=\"T_25121_row1_col3\" class=\"data row1 col3\" >0.057279</td>\n",
1211
+ " <td id=\"T_25121_row1_col4\" class=\"data row1 col4\" >0.050020</td>\n",
1212
+ " <td id=\"T_25121_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1213
+ " </tr>\n",
1214
+ " <tr>\n",
1215
+ " <th id=\"T_25121_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1216
+ " <td id=\"T_25121_row2_col0\" class=\"data row2 col0\" >0.860224</td>\n",
1217
+ " <td id=\"T_25121_row2_col1\" class=\"data row2 col1\" >0.851848</td>\n",
1218
+ " <td id=\"T_25121_row2_col2\" class=\"data row2 col2\" >0.854136</td>\n",
1219
+ " <td id=\"T_25121_row2_col3\" class=\"data row2 col3\" >0.853298</td>\n",
1220
+ " <td id=\"T_25121_row2_col4\" class=\"data row2 col4\" >0.856021</td>\n",
1221
+ " <td id=\"T_25121_row2_col5\" class=\"data row2 col5\" >0.854344</td>\n",
1222
+ " </tr>\n",
1223
+ " <tr>\n",
1224
+ " <th id=\"T_25121_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1225
+ " <td id=\"T_25121_row3_col0\" class=\"data row3 col0\" >0.814145</td>\n",
1226
+ " <td id=\"T_25121_row3_col1\" class=\"data row3 col1\" >0.804747</td>\n",
1227
+ " <td id=\"T_25121_row3_col2\" class=\"data row3 col2\" >0.808259</td>\n",
1228
+ " <td id=\"T_25121_row3_col3\" class=\"data row3 col3\" >0.807370</td>\n",
1229
+ " <td id=\"T_25121_row3_col4\" class=\"data row3 col4\" >0.810371</td>\n",
1230
+ " <td id=\"T_25121_row3_col5\" class=\"data row3 col5\" >0.808283</td>\n",
1231
+ " </tr>\n",
1232
+ " <tr>\n",
1233
+ " <th id=\"T_25121_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1234
+ " <td id=\"T_25121_row4_col0\" class=\"data row4 col0\" >0.809300</td>\n",
1235
+ " <td id=\"T_25121_row4_col1\" class=\"data row4 col1\" >0.793038</td>\n",
1236
+ " <td id=\"T_25121_row4_col2\" class=\"data row4 col2\" >0.794587</td>\n",
1237
+ " <td id=\"T_25121_row4_col3\" class=\"data row4 col3\" >0.792657</td>\n",
1238
+ " <td id=\"T_25121_row4_col4\" class=\"data row4 col4\" >0.797936</td>\n",
1239
+ " <td id=\"T_25121_row4_col5\" class=\"data row4 col5\" >0.795443</td>\n",
1240
+ " </tr>\n",
1241
+ " <tr>\n",
1242
+ " <th id=\"T_25121_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1243
+ " <td id=\"T_25121_row5_col0\" class=\"data row5 col0\" >0.819048</td>\n",
1244
+ " <td id=\"T_25121_row5_col1\" class=\"data row5 col1\" >0.816807</td>\n",
1245
+ " <td id=\"T_25121_row5_col2\" class=\"data row5 col2\" >0.822409</td>\n",
1246
+ " <td id=\"T_25121_row5_col3\" class=\"data row5 col3\" >0.822639</td>\n",
1247
+ " <td id=\"T_25121_row5_col4\" class=\"data row5 col4\" >0.823200</td>\n",
1248
+ " <td id=\"T_25121_row5_col5\" class=\"data row5 col5\" >0.821545</td>\n",
1249
+ " </tr>\n",
1250
+ " <tr>\n",
1251
+ " <th id=\"T_25121_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1252
+ " <td id=\"T_25121_row6_col0\" class=\"data row6 col0\" >0.908407</td>\n",
1253
+ " <td id=\"T_25121_row6_col1\" class=\"data row6 col1\" >0.906379</td>\n",
1254
+ " <td id=\"T_25121_row6_col2\" class=\"data row6 col2\" >0.910833</td>\n",
1255
+ " <td id=\"T_25121_row6_col3\" class=\"data row6 col3\" >0.907507</td>\n",
1256
+ " <td id=\"T_25121_row6_col4\" class=\"data row6 col4\" >0.908959</td>\n",
1257
+ " <td id=\"T_25121_row6_col5\" class=\"data row6 col5\" >0.908681</td>\n",
1258
+ " </tr>\n",
1259
+ " <tr>\n",
1260
+ " <th id=\"T_25121_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1261
+ " <td id=\"T_25121_row7_col0\" class=\"data row7 col0\" >-0.116893</td>\n",
1262
+ " <td id=\"T_25121_row7_col1\" class=\"data row7 col1\" >-0.119319</td>\n",
1263
+ " <td id=\"T_25121_row7_col2\" class=\"data row7 col2\" >-0.116034</td>\n",
1264
+ " <td id=\"T_25121_row7_col3\" class=\"data row7 col3\" >-0.119313</td>\n",
1265
+ " <td id=\"T_25121_row7_col4\" class=\"data row7 col4\" >-0.118294</td>\n",
1266
+ " <td id=\"T_25121_row7_col5\" class=\"data row7 col5\" >-0.118266</td>\n",
1267
+ " </tr>\n",
1268
+ " <tr>\n",
1269
+ " <th id=\"T_25121_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1270
+ " <td id=\"T_25121_row8_col0\" class=\"data row8 col0\" >-0.393473</td>\n",
1271
+ " <td id=\"T_25121_row8_col1\" class=\"data row8 col1\" >-0.395306</td>\n",
1272
+ " <td id=\"T_25121_row8_col2\" class=\"data row8 col2\" >-0.384403</td>\n",
1273
+ " <td id=\"T_25121_row8_col3\" class=\"data row8 col3\" >-0.397352</td>\n",
1274
+ " <td id=\"T_25121_row8_col4\" class=\"data row8 col4\" >-0.394224</td>\n",
1275
+ " <td id=\"T_25121_row8_col5\" class=\"data row8 col5\" >-0.392001</td>\n",
1276
+ " </tr>\n",
1277
+ " </tbody>\n",
1278
+ "</table>\n"
1279
+ ],
1280
+ "text/plain": [
1281
+ "<pandas.io.formats.style.Styler at 0x1a1a8656f80>"
1282
+ ]
1283
+ },
1284
+ "metadata": {},
1285
+ "output_type": "display_data"
1286
+ }
1287
+ ],
1288
+ "source": [
1289
+ "# XGB hyperparameter that deals with unbalanced\n",
1290
+ "scale_pos_weight = Y.mean()**-1\n",
1291
+ "\n",
1292
+ "# Creating the model objects\n",
1293
+ "cls_lr = LogisticRegression(\n",
1294
+ " class_weight=\"balanced\", # Hyperparameter to deal with unbalanced output\n",
1295
+ " random_state=lucky_num)\n",
1296
+ "# cls_svm = SVC(random_state=lucky_num) # Remove due its resource consumption and worst results\n",
1297
+ "cls_NB = GaussianNB()\n",
1298
+ "cls_knn = KNeighborsClassifier()\n",
1299
+ "cls_rf = RandomForestClassifier(\n",
1300
+ " random_state=lucky_num,\n",
1301
+ " class_weight=\"balanced_subsample\") # Hyperparameter to deal with unbalanced output\n",
1302
+ "cls_gbc = GradientBoostingClassifier(random_state=lucky_num)\n",
1303
+ "cls_xgb = xgb.XGBClassifier(\n",
1304
+ " objective=\"binary:logistic\",\n",
1305
+ " verbose=None,\n",
1306
+ " random_state=lucky_num,\n",
1307
+ " scale_pos_weight = scale_pos_weight)\n",
1308
+ "\n",
1309
+ "# Lists to iterate on our modeling function\n",
1310
+ "cls_name = [\"LR\", \"NB\", \"KNN\", \"RF\", \"GBC\", \"XGB\"]\n",
1311
+ "cls_list = [cls_lr, cls_NB, cls_knn, cls_rf, cls_gbc, cls_xgb]\n",
1312
+ "\n",
1313
+ "mdl_summaries = []\n",
1314
+ "for name, inst in zip(cls_name, cls_list):\n",
1315
+ " mdl_list = create_model(name, inst)\n",
1316
+ " mdl_list = [name] + mdl_list\n",
1317
+ " mdl_summaries.append(mdl_list)\n",
1318
+ "\n",
1319
+ "df_mdl = pd.DataFrame(\n",
1320
+ " mdl_summaries,\n",
1321
+ " columns=[\n",
1322
+ " \"model\",\n",
1323
+ " \"test_accuracy\",\n",
1324
+ " \"test_f1\",\n",
1325
+ " \"test_precision\",\n",
1326
+ " \"test_recall\",\n",
1327
+ " \"test_roc_auc\",\n",
1328
+ " \"test_brier\",\n",
1329
+ " \"test_log_loss\"])"
1330
+ ]
1331
+ },
1332
+ {
1333
+ "cell_type": "code",
1334
+ "execution_count": 9,
1335
+ "metadata": {},
1336
+ "outputs": [
1337
+ {
1338
+ "data": {
1339
+ "text/html": [
1340
+ "<style type=\"text/css\">\n",
1341
+ "</style>\n",
1342
+ "<table id=\"T_3ba63\">\n",
1343
+ " <caption>Test set validation scores</caption>\n",
1344
+ " <thead>\n",
1345
+ " <tr>\n",
1346
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1347
+ " <th id=\"T_3ba63_level0_col0\" class=\"col_heading level0 col0\" >model</th>\n",
1348
+ " <th id=\"T_3ba63_level0_col1\" class=\"col_heading level0 col1\" >test_accuracy</th>\n",
1349
+ " <th id=\"T_3ba63_level0_col2\" class=\"col_heading level0 col2\" >test_f1</th>\n",
1350
+ " <th id=\"T_3ba63_level0_col3\" class=\"col_heading level0 col3\" >test_precision</th>\n",
1351
+ " <th id=\"T_3ba63_level0_col4\" class=\"col_heading level0 col4\" >test_recall</th>\n",
1352
+ " <th id=\"T_3ba63_level0_col5\" class=\"col_heading level0 col5\" >test_roc_auc</th>\n",
1353
+ " <th id=\"T_3ba63_level0_col6\" class=\"col_heading level0 col6\" >test_brier</th>\n",
1354
+ " <th id=\"T_3ba63_level0_col7\" class=\"col_heading level0 col7\" >test_log_loss</th>\n",
1355
+ " </tr>\n",
1356
+ " </thead>\n",
1357
+ " <tbody>\n",
1358
+ " <tr>\n",
1359
+ " <th id=\"T_3ba63_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
1360
+ " <td id=\"T_3ba63_row0_col0\" class=\"data row0 col0\" >GBC</td>\n",
1361
+ " <td id=\"T_3ba63_row0_col1\" class=\"data row0 col1\" >0.871054</td>\n",
1362
+ " <td id=\"T_3ba63_row0_col2\" class=\"data row0 col2\" >0.817560</td>\n",
1363
+ " <td id=\"T_3ba63_row0_col3\" class=\"data row0 col3\" >0.867518</td>\n",
1364
+ " <td id=\"T_3ba63_row0_col4\" class=\"data row0 col4\" >0.773042</td>\n",
1365
+ " <td id=\"T_3ba63_row0_col5\" class=\"data row0 col5\" >0.908648</td>\n",
1366
+ " <td id=\"T_3ba63_row0_col6\" class=\"data row0 col6\" >0.104280</td>\n",
1367
+ " <td id=\"T_3ba63_row0_col7\" class=\"data row0 col7\" >0.350152</td>\n",
1368
+ " </tr>\n",
1369
+ " <tr>\n",
1370
+ " <th id=\"T_3ba63_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
1371
+ " <td id=\"T_3ba63_row1_col0\" class=\"data row1 col0\" >LR</td>\n",
1372
+ " <td id=\"T_3ba63_row1_col1\" class=\"data row1 col1\" >0.865924</td>\n",
1373
+ " <td id=\"T_3ba63_row1_col2\" class=\"data row1 col2\" >0.814469</td>\n",
1374
+ " <td id=\"T_3ba63_row1_col3\" class=\"data row1 col3\" >0.843439</td>\n",
1375
+ " <td id=\"T_3ba63_row1_col4\" class=\"data row1 col4\" >0.787423</td>\n",
1376
+ " <td id=\"T_3ba63_row1_col5\" class=\"data row1 col5\" >0.904458</td>\n",
1377
+ " <td id=\"T_3ba63_row1_col6\" class=\"data row1 col6\" >0.110435</td>\n",
1378
+ " <td id=\"T_3ba63_row1_col7\" class=\"data row1 col7\" >0.370350</td>\n",
1379
+ " </tr>\n",
1380
+ " <tr>\n",
1381
+ " <th id=\"T_3ba63_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
1382
+ " <td id=\"T_3ba63_row2_col0\" class=\"data row2 col0\" >XGB</td>\n",
1383
+ " <td id=\"T_3ba63_row2_col1\" class=\"data row2 col1\" >0.854344</td>\n",
1384
+ " <td id=\"T_3ba63_row2_col2\" class=\"data row2 col2\" >0.808283</td>\n",
1385
+ " <td id=\"T_3ba63_row2_col3\" class=\"data row2 col3\" >0.795443</td>\n",
1386
+ " <td id=\"T_3ba63_row2_col4\" class=\"data row2 col4\" >0.821545</td>\n",
1387
+ " <td id=\"T_3ba63_row2_col5\" class=\"data row2 col5\" >0.908681</td>\n",
1388
+ " <td id=\"T_3ba63_row2_col6\" class=\"data row2 col6\" >0.118266</td>\n",
1389
+ " <td id=\"T_3ba63_row2_col7\" class=\"data row2 col7\" >0.392001</td>\n",
1390
+ " </tr>\n",
1391
+ " <tr>\n",
1392
+ " <th id=\"T_3ba63_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
1393
+ " <td id=\"T_3ba63_row3_col0\" class=\"data row3 col0\" >RF</td>\n",
1394
+ " <td id=\"T_3ba63_row3_col1\" class=\"data row3 col1\" >0.856152</td>\n",
1395
+ " <td id=\"T_3ba63_row3_col2\" class=\"data row3 col2\" >0.800623</td>\n",
1396
+ " <td id=\"T_3ba63_row3_col3\" class=\"data row3 col3\" >0.830547</td>\n",
1397
+ " <td id=\"T_3ba63_row3_col4\" class=\"data row3 col4\" >0.772781</td>\n",
1398
+ " <td id=\"T_3ba63_row3_col5\" class=\"data row3 col5\" >0.893466</td>\n",
1399
+ " <td id=\"T_3ba63_row3_col6\" class=\"data row3 col6\" >0.115285</td>\n",
1400
+ " <td id=\"T_3ba63_row3_col7\" class=\"data row3 col7\" >0.562042</td>\n",
1401
+ " </tr>\n",
1402
+ " <tr>\n",
1403
+ " <th id=\"T_3ba63_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
1404
+ " <td id=\"T_3ba63_row4_col0\" class=\"data row4 col0\" >KNN</td>\n",
1405
+ " <td id=\"T_3ba63_row4_col1\" class=\"data row4 col1\" >0.843692</td>\n",
1406
+ " <td id=\"T_3ba63_row4_col2\" class=\"data row4 col2\" >0.779698</td>\n",
1407
+ " <td id=\"T_3ba63_row4_col3\" class=\"data row4 col3\" >0.823778</td>\n",
1408
+ " <td id=\"T_3ba63_row4_col4\" class=\"data row4 col4\" >0.740097</td>\n",
1409
+ " <td id=\"T_3ba63_row4_col5\" class=\"data row4 col5\" >0.872155</td>\n",
1410
+ " <td id=\"T_3ba63_row4_col6\" class=\"data row4 col6\" >0.128215</td>\n",
1411
+ " <td id=\"T_3ba63_row4_col7\" class=\"data row4 col7\" >1.877810</td>\n",
1412
+ " </tr>\n",
1413
+ " <tr>\n",
1414
+ " <th id=\"T_3ba63_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
1415
+ " <td id=\"T_3ba63_row5_col0\" class=\"data row5 col0\" >NB</td>\n",
1416
+ " <td id=\"T_3ba63_row5_col1\" class=\"data row5 col1\" >0.766637</td>\n",
1417
+ " <td id=\"T_3ba63_row5_col2\" class=\"data row5 col2\" >0.664795</td>\n",
1418
+ " <td id=\"T_3ba63_row5_col3\" class=\"data row5 col3\" >0.717684</td>\n",
1419
+ " <td id=\"T_3ba63_row5_col4\" class=\"data row5 col4\" >0.619166</td>\n",
1420
+ " <td id=\"T_3ba63_row5_col5\" class=\"data row5 col5\" >0.848834</td>\n",
1421
+ " <td id=\"T_3ba63_row5_col6\" class=\"data row5 col6\" >0.208278</td>\n",
1422
+ " <td id=\"T_3ba63_row5_col7\" class=\"data row5 col7\" >1.761326</td>\n",
1423
+ " </tr>\n",
1424
+ " </tbody>\n",
1425
+ "</table>\n"
1426
+ ],
1427
+ "text/plain": [
1428
+ "<pandas.io.formats.style.Styler at 0x1a1a8656bc0>"
1429
+ ]
1430
+ },
1431
+ "metadata": {},
1432
+ "output_type": "display_data"
1433
+ }
1434
+ ],
1435
+ "source": [
1436
+ "df_mdl.sort_values(\n",
1437
+ " \"test_f1\",\n",
1438
+ " ascending=False,\n",
1439
+ " inplace=True,\n",
1440
+ " ignore_index=True)\n",
1441
+ "\n",
1442
+ "display(df_mdl.style.set_caption(\"Test set validation scores\"))"
1443
+ ]
1444
+ },
1445
+ {
1446
+ "cell_type": "markdown",
1447
+ "metadata": {},
1448
+ "source": [
1449
+ "GBC, LR, XGB and RF preset great results! We have two ways here: hyperparameters tunning or creating a composite model. Let's begin with the composite model.\n"
1450
+ ]
1451
+ },
1452
+ {
1453
+ "cell_type": "code",
1454
+ "execution_count": 10,
1455
+ "metadata": {},
1456
+ "outputs": [
1457
+ {
1458
+ "data": {
1459
+ "text/html": [
1460
+ "<style type=\"text/css\">\n",
1461
+ "</style>\n",
1462
+ "<table id=\"T_04769\">\n",
1463
+ " <caption>Test set validation scores for Composite Model</caption>\n",
1464
+ " <thead>\n",
1465
+ " <tr>\n",
1466
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1467
+ " <th id=\"T_04769_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1468
+ " <th id=\"T_04769_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1469
+ " <th id=\"T_04769_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1470
+ " <th id=\"T_04769_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1471
+ " <th id=\"T_04769_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1472
+ " <th id=\"T_04769_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1473
+ " </tr>\n",
1474
+ " </thead>\n",
1475
+ " <tbody>\n",
1476
+ " <tr>\n",
1477
+ " <th id=\"T_04769_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1478
+ " <td id=\"T_04769_row0_col0\" class=\"data row0 col0\" >10.109613</td>\n",
1479
+ " <td id=\"T_04769_row0_col1\" class=\"data row0 col1\" >11.766011</td>\n",
1480
+ " <td id=\"T_04769_row0_col2\" class=\"data row0 col2\" >11.450818</td>\n",
1481
+ " <td id=\"T_04769_row0_col3\" class=\"data row0 col3\" >11.737634</td>\n",
1482
+ " <td id=\"T_04769_row0_col4\" class=\"data row0 col4\" >12.702598</td>\n",
1483
+ " <td id=\"T_04769_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1484
+ " </tr>\n",
1485
+ " <tr>\n",
1486
+ " <th id=\"T_04769_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1487
+ " <td id=\"T_04769_row1_col0\" class=\"data row1 col0\" >0.490518</td>\n",
1488
+ " <td id=\"T_04769_row1_col1\" class=\"data row1 col1\" >0.532695</td>\n",
1489
+ " <td id=\"T_04769_row1_col2\" class=\"data row1 col2\" >0.529459</td>\n",
1490
+ " <td id=\"T_04769_row1_col3\" class=\"data row1 col3\" >0.549051</td>\n",
1491
+ " <td id=\"T_04769_row1_col4\" class=\"data row1 col4\" >0.586749</td>\n",
1492
+ " <td id=\"T_04769_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1493
+ " </tr>\n",
1494
+ " <tr>\n",
1495
+ " <th id=\"T_04769_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1496
+ " <td id=\"T_04769_row2_col0\" class=\"data row2 col0\" >0.870799</td>\n",
1497
+ " <td id=\"T_04769_row2_col1\" class=\"data row2 col1\" >0.871532</td>\n",
1498
+ " <td id=\"T_04769_row2_col2\" class=\"data row2 col2\" >0.875497</td>\n",
1499
+ " <td id=\"T_04769_row2_col3\" class=\"data row2 col3\" >0.869948</td>\n",
1500
+ " <td id=\"T_04769_row2_col4\" class=\"data row2 col4\" >0.869215</td>\n",
1501
+ " <td id=\"T_04769_row2_col5\" class=\"data row2 col5\" >0.869002</td>\n",
1502
+ " </tr>\n",
1503
+ " <tr>\n",
1504
+ " <th id=\"T_04769_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1505
+ " <td id=\"T_04769_row3_col0\" class=\"data row3 col0\" >0.818689</td>\n",
1506
+ " <td id=\"T_04769_row3_col1\" class=\"data row3 col1\" >0.820797</td>\n",
1507
+ " <td id=\"T_04769_row3_col2\" class=\"data row3 col2\" >0.826297</td>\n",
1508
+ " <td id=\"T_04769_row3_col3\" class=\"data row3 col3\" >0.819319</td>\n",
1509
+ " <td id=\"T_04769_row3_col4\" class=\"data row3 col4\" >0.817531</td>\n",
1510
+ " <td id=\"T_04769_row3_col5\" class=\"data row3 col5\" >0.817283</td>\n",
1511
+ " </tr>\n",
1512
+ " <tr>\n",
1513
+ " <th id=\"T_04769_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1514
+ " <td id=\"T_04769_row4_col0\" class=\"data row4 col0\" >0.860939</td>\n",
1515
+ " <td id=\"T_04769_row4_col1\" class=\"data row4 col1\" >0.857492</td>\n",
1516
+ " <td id=\"T_04769_row4_col2\" class=\"data row4 col2\" >0.863511</td>\n",
1517
+ " <td id=\"T_04769_row4_col3\" class=\"data row4 col3\" >0.852042</td>\n",
1518
+ " <td id=\"T_04769_row4_col4\" class=\"data row4 col4\" >0.854090</td>\n",
1519
+ " <td id=\"T_04769_row4_col5\" class=\"data row4 col5\" >0.853645</td>\n",
1520
+ " </tr>\n",
1521
+ " <tr>\n",
1522
+ " <th id=\"T_04769_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1523
+ " <td id=\"T_04769_row5_col0\" class=\"data row5 col0\" >0.780392</td>\n",
1524
+ " <td id=\"T_04769_row5_col1\" class=\"data row5 col1\" >0.787115</td>\n",
1525
+ " <td id=\"T_04769_row5_col2\" class=\"data row5 col2\" >0.792157</td>\n",
1526
+ " <td id=\"T_04769_row5_col3\" class=\"data row5 col3\" >0.789017</td>\n",
1527
+ " <td id=\"T_04769_row5_col4\" class=\"data row5 col4\" >0.783973</td>\n",
1528
+ " <td id=\"T_04769_row5_col5\" class=\"data row5 col5\" >0.783893</td>\n",
1529
+ " </tr>\n",
1530
+ " <tr>\n",
1531
+ " <th id=\"T_04769_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1532
+ " <td id=\"T_04769_row6_col0\" class=\"data row6 col0\" >0.909022</td>\n",
1533
+ " <td id=\"T_04769_row6_col1\" class=\"data row6 col1\" >0.908890</td>\n",
1534
+ " <td id=\"T_04769_row6_col2\" class=\"data row6 col2\" >0.912418</td>\n",
1535
+ " <td id=\"T_04769_row6_col3\" class=\"data row6 col3\" >0.907315</td>\n",
1536
+ " <td id=\"T_04769_row6_col4\" class=\"data row6 col4\" >0.910340</td>\n",
1537
+ " <td id=\"T_04769_row6_col5\" class=\"data row6 col5\" >0.910500</td>\n",
1538
+ " </tr>\n",
1539
+ " <tr>\n",
1540
+ " <th id=\"T_04769_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1541
+ " <td id=\"T_04769_row7_col0\" class=\"data row7 col0\" >-0.105818</td>\n",
1542
+ " <td id=\"T_04769_row7_col1\" class=\"data row7 col1\" >-0.105354</td>\n",
1543
+ " <td id=\"T_04769_row7_col2\" class=\"data row7 col2\" >-0.101743</td>\n",
1544
+ " <td id=\"T_04769_row7_col3\" class=\"data row7 col3\" >-0.106567</td>\n",
1545
+ " <td id=\"T_04769_row7_col4\" class=\"data row7 col4\" >-0.105957</td>\n",
1546
+ " <td id=\"T_04769_row7_col5\" class=\"data row7 col5\" >-0.105743</td>\n",
1547
+ " </tr>\n",
1548
+ " <tr>\n",
1549
+ " <th id=\"T_04769_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1550
+ " <td id=\"T_04769_row8_col0\" class=\"data row8 col0\" >-0.356051</td>\n",
1551
+ " <td id=\"T_04769_row8_col1\" class=\"data row8 col1\" >-0.353269</td>\n",
1552
+ " <td id=\"T_04769_row8_col2\" class=\"data row8 col2\" >-0.344184</td>\n",
1553
+ " <td id=\"T_04769_row8_col3\" class=\"data row8 col3\" >-0.357062</td>\n",
1554
+ " <td id=\"T_04769_row8_col4\" class=\"data row8 col4\" >-0.355010</td>\n",
1555
+ " <td id=\"T_04769_row8_col5\" class=\"data row8 col5\" >-0.353621</td>\n",
1556
+ " </tr>\n",
1557
+ " </tbody>\n",
1558
+ "</table>\n"
1559
+ ],
1560
+ "text/plain": [
1561
+ "<pandas.io.formats.style.Styler at 0x1a1859d3520>"
1562
+ ]
1563
+ },
1564
+ "metadata": {},
1565
+ "output_type": "display_data"
1566
+ }
1567
+ ],
1568
+ "source": [
1569
+ "# Selecting the models\n",
1570
+ "cls_name = [\"GBC\", \"XGB\", \"LR\", \"RF\",]\n",
1571
+ "cls_list = [cls_gbc, cls_xgb, cls_lr, cls_rf]\n",
1572
+ "\n",
1573
+ "# Training the voting classifier\n",
1574
+ "cls_vot = VotingClassifier([*zip(cls_name, cls_list)], voting=\"soft\")\n",
1575
+ "cls_vot.fit(X_train, y_train)\n",
1576
+ "\n",
1577
+ "# Using cross-validation to evaluate the model fitted\n",
1578
+ "cls_cross = cross_validate(\n",
1579
+ " estimator=cls_vot,\n",
1580
+ " X=X_train,\n",
1581
+ " y=y_train,\n",
1582
+ " cv=5,\n",
1583
+ " scoring=scores)\n",
1584
+ "\n",
1585
+ "df_vot = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
1586
+ "\n",
1587
+ "# Calculating score to test set\n",
1588
+ "accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls_vot)\n",
1589
+ "\n",
1590
+ "# Filling a dataframe to better presentation\n",
1591
+ "df_vot.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
1592
+ "df_vot.at[\"test_f1\", \"TestSet\"] = f1\n",
1593
+ "df_vot.at[\"test_recall\", \"TestSet\"] = recall\n",
1594
+ "df_vot.at[\"test_precision\", \"TestSet\"] = precision\n",
1595
+ "df_vot.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
1596
+ "df_vot.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
1597
+ "df_vot.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
1598
+ "\n",
1599
+ "display(df_vot.style.set_caption(\"Test set validation scores for Composite Model\"))"
1600
+ ]
1601
+ },
1602
+ {
1603
+ "cell_type": "markdown",
1604
+ "metadata": {},
1605
+ "source": [
1606
+ "The composite model does not present any evidence of overfitting. For now, we will use it on our app."
1607
+ ]
1608
+ },
1609
+ {
1610
+ "cell_type": "code",
1611
+ "execution_count": 11,
1612
+ "metadata": {},
1613
+ "outputs": [
1614
+ {
1615
+ "data": {
1616
+ "text/plain": [
1617
+ "['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\model_feridos.pkl']"
1618
+ ]
1619
+ },
1620
+ "execution_count": 11,
1621
+ "metadata": {},
1622
+ "output_type": "execute_result"
1623
+ }
1624
+ ],
1625
+ "source": [
1626
+ "# Saving\n",
1627
+ "file_name = \"model_\" + output + '.pkl'\n",
1628
+ "jb.dump(cls_vot, path.join(path.abspath(\"./\"), file_name))"
1629
+ ]
1630
+ }
1631
+ ],
1632
+ "metadata": {
1633
+ "kernelspec": {
1634
+ "display_name": "Python 3.10.6 64-bit",
1635
+ "language": "python",
1636
+ "name": "python3"
1637
+ },
1638
+ "language_info": {
1639
+ "codemirror_mode": {
1640
+ "name": "ipython",
1641
+ "version": 3
1642
+ },
1643
+ "file_extension": ".py",
1644
+ "mimetype": "text/x-python",
1645
+ "name": "python",
1646
+ "nbconvert_exporter": "python",
1647
+ "pygments_lexer": "ipython3",
1648
+ "version": "3.10.6"
1649
+ },
1650
+ "orig_nbformat": 4,
1651
+ "vscode": {
1652
+ "interpreter": {
1653
+ "hash": "1372d04dbd71fdc5436c5d6e671c1b9287e750e86143c81b5a7ba0acaf653c5e"
1654
+ }
1655
+ }
1656
+ },
1657
+ "nbformat": 4,
1658
+ "nbformat_minor": 2
1659
+ }
model/model_feridos.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee61d7ec89de8ff0cf49904823a9a892556807449852976ede436fa059ed765a
3
+ size 248858704
model/model_feridos_gr.ipynb ADDED
@@ -0,0 +1,1659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 1. Introduction\n",
8
+ "\n",
9
+ "This notebook was written to train Porto Alegre Traffic Accidents Data after the first cleaning, processing, and transforming step. This was made in a notebook in the `data` folder. In truth, we will have 3 models.\n",
10
+ "\n",
11
+ "1. Predict the probability of injured people.\n",
12
+ "\n",
13
+ "2. Predict the probability of seriously injured people.\n",
14
+ "\n",
15
+ "3. Predict the probability of dead people in the event or after it.\n",
16
+ "\n",
17
+ "The path to training the models will be the same, just make some filtering on data and analyze the results properly."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "# 2. Data Loading"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 5,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "data": {
34
+ "text/html": [
35
+ "<div>\n",
36
+ "<style scoped>\n",
37
+ " .dataframe tbody tr th:only-of-type {\n",
38
+ " vertical-align: middle;\n",
39
+ " }\n",
40
+ "\n",
41
+ " .dataframe tbody tr th {\n",
42
+ " vertical-align: top;\n",
43
+ " }\n",
44
+ "\n",
45
+ " .dataframe thead th {\n",
46
+ " text-align: right;\n",
47
+ " }\n",
48
+ "</style>\n",
49
+ "<table border=\"1\" class=\"dataframe\">\n",
50
+ " <thead>\n",
51
+ " <tr style=\"text-align: right;\">\n",
52
+ " <th></th>\n",
53
+ " <th>0</th>\n",
54
+ " <th>1</th>\n",
55
+ " <th>2</th>\n",
56
+ " </tr>\n",
57
+ " </thead>\n",
58
+ " <tbody>\n",
59
+ " <tr>\n",
60
+ " <th>latitude</th>\n",
61
+ " <td>-30.009614</td>\n",
62
+ " <td>-30.0403</td>\n",
63
+ " <td>-30.069</td>\n",
64
+ " </tr>\n",
65
+ " <tr>\n",
66
+ " <th>longitude</th>\n",
67
+ " <td>-51.185581</td>\n",
68
+ " <td>-51.1958</td>\n",
69
+ " <td>-51.1437</td>\n",
70
+ " </tr>\n",
71
+ " <tr>\n",
72
+ " <th>feridos</th>\n",
73
+ " <td>True</td>\n",
74
+ " <td>True</td>\n",
75
+ " <td>True</td>\n",
76
+ " </tr>\n",
77
+ " <tr>\n",
78
+ " <th>feridos_gr</th>\n",
79
+ " <td>False</td>\n",
80
+ " <td>False</td>\n",
81
+ " <td>False</td>\n",
82
+ " </tr>\n",
83
+ " <tr>\n",
84
+ " <th>fatais</th>\n",
85
+ " <td>False</td>\n",
86
+ " <td>False</td>\n",
87
+ " <td>False</td>\n",
88
+ " </tr>\n",
89
+ " <tr>\n",
90
+ " <th>caminhao</th>\n",
91
+ " <td>False</td>\n",
92
+ " <td>False</td>\n",
93
+ " <td>False</td>\n",
94
+ " </tr>\n",
95
+ " <tr>\n",
96
+ " <th>moto</th>\n",
97
+ " <td>True</td>\n",
98
+ " <td>True</td>\n",
99
+ " <td>False</td>\n",
100
+ " </tr>\n",
101
+ " <tr>\n",
102
+ " <th>cars</th>\n",
103
+ " <td>True</td>\n",
104
+ " <td>True</td>\n",
105
+ " <td>True</td>\n",
106
+ " </tr>\n",
107
+ " <tr>\n",
108
+ " <th>transport</th>\n",
109
+ " <td>False</td>\n",
110
+ " <td>False</td>\n",
111
+ " <td>False</td>\n",
112
+ " </tr>\n",
113
+ " <tr>\n",
114
+ " <th>others</th>\n",
115
+ " <td>False</td>\n",
116
+ " <td>False</td>\n",
117
+ " <td>False</td>\n",
118
+ " </tr>\n",
119
+ " <tr>\n",
120
+ " <th>holiday</th>\n",
121
+ " <td>False</td>\n",
122
+ " <td>True</td>\n",
123
+ " <td>True</td>\n",
124
+ " </tr>\n",
125
+ " <tr>\n",
126
+ " <th>day_1</th>\n",
127
+ " <td>0</td>\n",
128
+ " <td>0</td>\n",
129
+ " <td>0</td>\n",
130
+ " </tr>\n",
131
+ " <tr>\n",
132
+ " <th>day_2</th>\n",
133
+ " <td>0</td>\n",
134
+ " <td>0</td>\n",
135
+ " <td>0</td>\n",
136
+ " </tr>\n",
137
+ " <tr>\n",
138
+ " <th>day_3</th>\n",
139
+ " <td>0</td>\n",
140
+ " <td>0</td>\n",
141
+ " <td>0</td>\n",
142
+ " </tr>\n",
143
+ " <tr>\n",
144
+ " <th>day_4</th>\n",
145
+ " <td>0</td>\n",
146
+ " <td>0</td>\n",
147
+ " <td>0</td>\n",
148
+ " </tr>\n",
149
+ " <tr>\n",
150
+ " <th>day_5</th>\n",
151
+ " <td>1</td>\n",
152
+ " <td>0</td>\n",
153
+ " <td>0</td>\n",
154
+ " </tr>\n",
155
+ " <tr>\n",
156
+ " <th>day_6</th>\n",
157
+ " <td>0</td>\n",
158
+ " <td>1</td>\n",
159
+ " <td>1</td>\n",
160
+ " </tr>\n",
161
+ " <tr>\n",
162
+ " <th>hour_1</th>\n",
163
+ " <td>0</td>\n",
164
+ " <td>0</td>\n",
165
+ " <td>0</td>\n",
166
+ " </tr>\n",
167
+ " <tr>\n",
168
+ " <th>hour_2</th>\n",
169
+ " <td>0</td>\n",
170
+ " <td>0</td>\n",
171
+ " <td>0</td>\n",
172
+ " </tr>\n",
173
+ " <tr>\n",
174
+ " <th>hour_3</th>\n",
175
+ " <td>0</td>\n",
176
+ " <td>0</td>\n",
177
+ " <td>0</td>\n",
178
+ " </tr>\n",
179
+ " <tr>\n",
180
+ " <th>hour_4</th>\n",
181
+ " <td>0</td>\n",
182
+ " <td>0</td>\n",
183
+ " <td>0</td>\n",
184
+ " </tr>\n",
185
+ " <tr>\n",
186
+ " <th>hour_5</th>\n",
187
+ " <td>0</td>\n",
188
+ " <td>0</td>\n",
189
+ " <td>0</td>\n",
190
+ " </tr>\n",
191
+ " <tr>\n",
192
+ " <th>hour_6</th>\n",
193
+ " <td>0</td>\n",
194
+ " <td>0</td>\n",
195
+ " <td>0</td>\n",
196
+ " </tr>\n",
197
+ " <tr>\n",
198
+ " <th>hour_7</th>\n",
199
+ " <td>0</td>\n",
200
+ " <td>0</td>\n",
201
+ " <td>0</td>\n",
202
+ " </tr>\n",
203
+ " <tr>\n",
204
+ " <th>hour_8</th>\n",
205
+ " <td>0</td>\n",
206
+ " <td>0</td>\n",
207
+ " <td>0</td>\n",
208
+ " </tr>\n",
209
+ " <tr>\n",
210
+ " <th>hour_9</th>\n",
211
+ " <td>0</td>\n",
212
+ " <td>0</td>\n",
213
+ " <td>0</td>\n",
214
+ " </tr>\n",
215
+ " <tr>\n",
216
+ " <th>hour_10</th>\n",
217
+ " <td>0</td>\n",
218
+ " <td>1</td>\n",
219
+ " <td>0</td>\n",
220
+ " </tr>\n",
221
+ " <tr>\n",
222
+ " <th>hour_11</th>\n",
223
+ " <td>0</td>\n",
224
+ " <td>0</td>\n",
225
+ " <td>0</td>\n",
226
+ " </tr>\n",
227
+ " <tr>\n",
228
+ " <th>hour_12</th>\n",
229
+ " <td>0</td>\n",
230
+ " <td>0</td>\n",
231
+ " <td>0</td>\n",
232
+ " </tr>\n",
233
+ " <tr>\n",
234
+ " <th>hour_13</th>\n",
235
+ " <td>0</td>\n",
236
+ " <td>0</td>\n",
237
+ " <td>0</td>\n",
238
+ " </tr>\n",
239
+ " <tr>\n",
240
+ " <th>hour_14</th>\n",
241
+ " <td>0</td>\n",
242
+ " <td>0</td>\n",
243
+ " <td>0</td>\n",
244
+ " </tr>\n",
245
+ " <tr>\n",
246
+ " <th>hour_15</th>\n",
247
+ " <td>0</td>\n",
248
+ " <td>0</td>\n",
249
+ " <td>0</td>\n",
250
+ " </tr>\n",
251
+ " <tr>\n",
252
+ " <th>hour_16</th>\n",
253
+ " <td>0</td>\n",
254
+ " <td>0</td>\n",
255
+ " <td>0</td>\n",
256
+ " </tr>\n",
257
+ " <tr>\n",
258
+ " <th>hour_17</th>\n",
259
+ " <td>0</td>\n",
260
+ " <td>0</td>\n",
261
+ " <td>0</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <th>hour_18</th>\n",
265
+ " <td>0</td>\n",
266
+ " <td>0</td>\n",
267
+ " <td>0</td>\n",
268
+ " </tr>\n",
269
+ " <tr>\n",
270
+ " <th>hour_19</th>\n",
271
+ " <td>1</td>\n",
272
+ " <td>0</td>\n",
273
+ " <td>1</td>\n",
274
+ " </tr>\n",
275
+ " <tr>\n",
276
+ " <th>hour_20</th>\n",
277
+ " <td>0</td>\n",
278
+ " <td>0</td>\n",
279
+ " <td>0</td>\n",
280
+ " </tr>\n",
281
+ " <tr>\n",
282
+ " <th>hour_21</th>\n",
283
+ " <td>0</td>\n",
284
+ " <td>0</td>\n",
285
+ " <td>0</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <th>hour_22</th>\n",
289
+ " <td>0</td>\n",
290
+ " <td>0</td>\n",
291
+ " <td>0</td>\n",
292
+ " </tr>\n",
293
+ " <tr>\n",
294
+ " <th>hour_23</th>\n",
295
+ " <td>0</td>\n",
296
+ " <td>0</td>\n",
297
+ " <td>0</td>\n",
298
+ " </tr>\n",
299
+ " <tr>\n",
300
+ " <th>type_ATROPELAMENTO</th>\n",
301
+ " <td>0</td>\n",
302
+ " <td>0</td>\n",
303
+ " <td>1</td>\n",
304
+ " </tr>\n",
305
+ " <tr>\n",
306
+ " <th>type_CHOQUE</th>\n",
307
+ " <td>0</td>\n",
308
+ " <td>0</td>\n",
309
+ " <td>0</td>\n",
310
+ " </tr>\n",
311
+ " <tr>\n",
312
+ " <th>type_COLISÃO</th>\n",
313
+ " <td>0</td>\n",
314
+ " <td>0</td>\n",
315
+ " <td>0</td>\n",
316
+ " </tr>\n",
317
+ " <tr>\n",
318
+ " <th>type_OUTROS</th>\n",
319
+ " <td>0</td>\n",
320
+ " <td>0</td>\n",
321
+ " <td>0</td>\n",
322
+ " </tr>\n",
323
+ " </tbody>\n",
324
+ "</table>\n",
325
+ "</div>"
326
+ ],
327
+ "text/plain": [
328
+ " 0 1 2\n",
329
+ "latitude -30.009614 -30.0403 -30.069\n",
330
+ "longitude -51.185581 -51.1958 -51.1437\n",
331
+ "feridos True True True\n",
332
+ "feridos_gr False False False\n",
333
+ "fatais False False False\n",
334
+ "caminhao False False False\n",
335
+ "moto True True False\n",
336
+ "cars True True True\n",
337
+ "transport False False False\n",
338
+ "others False False False\n",
339
+ "holiday False True True\n",
340
+ "day_1 0 0 0\n",
341
+ "day_2 0 0 0\n",
342
+ "day_3 0 0 0\n",
343
+ "day_4 0 0 0\n",
344
+ "day_5 1 0 0\n",
345
+ "day_6 0 1 1\n",
346
+ "hour_1 0 0 0\n",
347
+ "hour_2 0 0 0\n",
348
+ "hour_3 0 0 0\n",
349
+ "hour_4 0 0 0\n",
350
+ "hour_5 0 0 0\n",
351
+ "hour_6 0 0 0\n",
352
+ "hour_7 0 0 0\n",
353
+ "hour_8 0 0 0\n",
354
+ "hour_9 0 0 0\n",
355
+ "hour_10 0 1 0\n",
356
+ "hour_11 0 0 0\n",
357
+ "hour_12 0 0 0\n",
358
+ "hour_13 0 0 0\n",
359
+ "hour_14 0 0 0\n",
360
+ "hour_15 0 0 0\n",
361
+ "hour_16 0 0 0\n",
362
+ "hour_17 0 0 0\n",
363
+ "hour_18 0 0 0\n",
364
+ "hour_19 1 0 1\n",
365
+ "hour_20 0 0 0\n",
366
+ "hour_21 0 0 0\n",
367
+ "hour_22 0 0 0\n",
368
+ "hour_23 0 0 0\n",
369
+ "type_ATROPELAMENTO 0 0 1\n",
370
+ "type_CHOQUE 0 0 0\n",
371
+ "type_COLISÃO 0 0 0\n",
372
+ "type_OUTROS 0 0 0"
373
+ ]
374
+ },
375
+ "execution_count": 5,
376
+ "metadata": {},
377
+ "output_type": "execute_result"
378
+ }
379
+ ],
380
+ "source": [
381
+ "import os.path as path\n",
382
+ "from pandas import read_csv\n",
383
+ "\n",
384
+ "file_csv = path.abspath(\"../\")\n",
385
+ "\n",
386
+ "file_csv = path.join(file_csv, \"data\" ,\"accidents_trans.csv\")\n",
387
+ "\n",
388
+ "accidents_trans = read_csv(file_csv)\n",
389
+ "\n",
390
+ "accidents_trans.head(3).T"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "# 3. Data Preparation"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 6,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "import joblib as jb # Use to save the model to deploy\n",
407
+ "from sklearn.preprocessing import StandardScaler\n",
408
+ "from sklearn.model_selection import train_test_split"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 7,
414
+ "metadata": {},
415
+ "outputs": [
416
+ {
417
+ "name": "stdout",
418
+ "output_type": "stream",
419
+ "text": [
420
+ "Our model to predict the probability of feridos_gr will be create with 25497 rows and 41 features.\n"
421
+ ]
422
+ }
423
+ ],
424
+ "source": [
425
+ "outputs = [\"feridos\", \"feridos_gr\", \"fatais\"]\n",
426
+ "inputs = [col for col in accidents_trans.columns if col not in outputs]\n",
427
+ "\n",
428
+ "X = accidents_trans[inputs].copy()\n",
429
+ "Y = accidents_trans[outputs].copy()\n",
430
+ "\n",
431
+ "# Filtering data considering the output\n",
432
+ "output = \"feridos_gr\"\n",
433
+ "\n",
434
+ "if output == \"feridos_gr\":\n",
435
+ " X = X[Y[\"feridos\"]]\n",
436
+ " Y = Y.loc[Y[\"feridos\"], \"feridos_gr\"]\n",
437
+ "elif output == \"fatais\":\n",
438
+ " X = X[Y[\"feridos_gr\"]]\n",
439
+ " Y = Y.loc[Y[\"feridos_gr\"], \"fatais\"]\n",
440
+ "else:\n",
441
+ " Y = Y[\"feridos\"]\n",
442
+ "\n",
443
+ "print(f\"Our model to predict the probability of \" \\\n",
444
+ " f\"{output} will be create with {X.shape[0]} \" \\\n",
445
+ " f\"rows and {X.shape[1]} features.\")"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": 8,
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "import csv\n",
455
+ "\n",
456
+ "with open(\"model_features.csv\", 'w') as f:\n",
457
+ " writer = csv.writer(f)\n",
458
+ " writer.writerow(X.columns)"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "markdown",
463
+ "metadata": {},
464
+ "source": [
465
+ "Considering that we will use models scaling sensitive, we will need to scale our data first. Beside this, we will need to save our scaler for future use."
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": 9,
471
+ "metadata": {},
472
+ "outputs": [
473
+ {
474
+ "data": {
475
+ "text/plain": [
476
+ "['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\scaler_feridos_gr.pkl']"
477
+ ]
478
+ },
479
+ "execution_count": 9,
480
+ "metadata": {},
481
+ "output_type": "execute_result"
482
+ }
483
+ ],
484
+ "source": [
485
+ "# Setting the random state using my luck number :-)\n",
486
+ "lucky_num = 7\n",
487
+ "\n",
488
+ "# X_train and y_train to train our model\n",
489
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
490
+ " X,\n",
491
+ " Y,\n",
492
+ " test_size=0.30,\n",
493
+ " random_state=lucky_num,\n",
494
+ " shuffle=True, # Used because our data is sort by date\n",
495
+ " stratify=Y) # Used because our data is unbalanced\n",
496
+ "\n",
497
+ "# Scaling\n",
498
+ "scaler = StandardScaler()\n",
499
+ "X_train = scaler.fit_transform(X_train)\n",
500
+ "X_test = scaler.transform(X_test)\n",
501
+ "\n",
502
+ "# Saving scaler\n",
503
+ "file_name = \"scaler_\" + output + '.pkl'\n",
504
+ "jb.dump(scaler, path.join(path.abspath(\"./\"), file_name))"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "metadata": {},
510
+ "source": [
511
+ "# 4. Data Modeling\n",
512
+ "\n",
513
+ "We will create and use cross-validation to evaluate the following models:\n",
514
+ "\n",
515
+ "- Logistic Regression;\n",
516
+ "\n",
517
+ "- Gaussian Naive Bayes;\n",
518
+ "\n",
519
+ "- K Neighbors;\n",
520
+ "\n",
521
+ "- Random Forest;\n",
522
+ "\n",
523
+ "- Gradient Boosting; and,\n",
524
+ "\n",
525
+ "- XGBoost.\n",
526
+ "\n",
527
+ "We will use two scores to select and evaluate our models:\n",
528
+ "\n",
529
+ "- F1 score: composition between the precision (how much our model correct classify every true label) and recall (how moch our model correct indicate true labels); and,\n",
530
+ "\n",
531
+ "- Brier score: average between the correct and the predict probability.\n",
532
+ "\n",
533
+ "However, we will see other metrics to support our decision:\n",
534
+ "\n",
535
+ "- Accurancy;\n",
536
+ "\n",
537
+ "- ROC_AOC; and,\n",
538
+ "\n",
539
+ "- Log loss (an other way to quantify the quality of probability predictions).\n",
540
+ "\n",
541
+ "And, before you go, we will find for each model if there is a hyperparameter to deal with the unbalanced output."
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": 10,
547
+ "metadata": {},
548
+ "outputs": [],
549
+ "source": [
550
+ "import pandas as pd\n",
551
+ "import xgboost as xgb\n",
552
+ "from sklearn.naive_bayes import GaussianNB\n",
553
+ "from sklearn.neighbors import KNeighborsClassifier\n",
554
+ "from sklearn.linear_model import LogisticRegression\n",
555
+ "from sklearn.model_selection import cross_validate \n",
556
+ "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier\n",
557
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, brier_score_loss, log_loss\n",
558
+ "\n",
559
+ "scores = [\"accuracy\", \"f1\", \"precision\", \"recall\", \"roc_auc\", \"neg_brier_score\",\"neg_log_loss\"]"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": 11,
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": [
568
+ "def eval_model(cls) -> tuple:\n",
569
+ " \"\"\"This function will calculate the metrics\n",
570
+ " to evaluate a classification model.\n",
571
+ " \"\"\"\n",
572
+ " # Predicting labels and probabilities\n",
573
+ " y_pred = cls.predict(X_test)\n",
574
+ " y_prob = cls.predict_proba(X_test)[:,1]\n",
575
+ "\n",
576
+ " # Calculating scores\n",
577
+ " accurancy = accuracy_score(y_test, y_pred)\n",
578
+ " f1 = f1_score(y_test, y_pred)\n",
579
+ " recall = recall_score(y_test, y_pred)\n",
580
+ " precision = precision_score(y_test, y_pred)\n",
581
+ " roc_auc = roc_auc_score(y_test, y_prob) # https://datascience.stackexchange.com/questions/114394/does-roc-auc-different-between-crossval-and-test-set-indicate-overfitting-or-oth\n",
582
+ " brier_score = brier_score_loss(y_test, y_prob)\n",
583
+ " log_loss_value = log_loss(y_test, y_prob)\n",
584
+ "\n",
585
+ " return accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value\n",
586
+ "\n",
587
+ "def create_model(name: str, cls) -> list:\n",
588
+ " \"\"\"This function will create some models\n",
589
+ " and return scores to evaluate it.\"\"\"\n",
590
+ " # Ftting model\n",
591
+ " cls.fit(X_train, y_train)\n",
592
+ "\n",
593
+ " # Using cross-validation to evaluate the model fitted\n",
594
+ " cls_cross = cross_validate(\n",
595
+ " estimator=cls,\n",
596
+ " X=X_train,\n",
597
+ " y=y_train,\n",
598
+ " cv=5,\n",
599
+ " scoring=scores)\n",
600
+ "\n",
601
+ " df_cv = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
602
+ "\n",
603
+ " # Calculating score to test set\n",
604
+ " accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls)\n",
605
+ "\n",
606
+ " # Filling a dataframe to better presentation\n",
607
+ " df_cv.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
608
+ " df_cv.at[\"test_f1\", \"TestSet\"] = f1\n",
609
+ " df_cv.at[\"test_recall\", \"TestSet\"] = recall\n",
610
+ " df_cv.at[\"test_precision\", \"TestSet\"] = precision\n",
611
+ " df_cv.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
612
+ " df_cv.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
613
+ " df_cv.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
614
+ "\n",
615
+ " caption = f\"{name} Validation Scores\"\n",
616
+ "\n",
617
+ " display(df_cv.style.set_caption(caption))\n",
618
+ "\n",
619
+ " return [accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value]"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": 15,
625
+ "metadata": {},
626
+ "outputs": [
627
+ {
628
+ "data": {
629
+ "text/html": [
630
+ "<style type=\"text/css\">\n",
631
+ "</style>\n",
632
+ "<table id=\"T_440e0\">\n",
633
+ " <caption>LR Validation Scores</caption>\n",
634
+ " <thead>\n",
635
+ " <tr>\n",
636
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
637
+ " <th id=\"T_440e0_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
638
+ " <th id=\"T_440e0_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
639
+ " <th id=\"T_440e0_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
640
+ " <th id=\"T_440e0_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
641
+ " <th id=\"T_440e0_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
642
+ " <th id=\"T_440e0_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
643
+ " </tr>\n",
644
+ " </thead>\n",
645
+ " <tbody>\n",
646
+ " <tr>\n",
647
+ " <th id=\"T_440e0_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
648
+ " <td id=\"T_440e0_row0_col0\" class=\"data row0 col0\" >0.037210</td>\n",
649
+ " <td id=\"T_440e0_row0_col1\" class=\"data row0 col1\" >0.031787</td>\n",
650
+ " <td id=\"T_440e0_row0_col2\" class=\"data row0 col2\" >0.031054</td>\n",
651
+ " <td id=\"T_440e0_row0_col3\" class=\"data row0 col3\" >0.030418</td>\n",
652
+ " <td id=\"T_440e0_row0_col4\" class=\"data row0 col4\" >0.031932</td>\n",
653
+ " <td id=\"T_440e0_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
654
+ " </tr>\n",
655
+ " <tr>\n",
656
+ " <th id=\"T_440e0_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
657
+ " <td id=\"T_440e0_row1_col0\" class=\"data row1 col0\" >0.000000</td>\n",
658
+ " <td id=\"T_440e0_row1_col1\" class=\"data row1 col1\" >0.008581</td>\n",
659
+ " <td id=\"T_440e0_row1_col2\" class=\"data row1 col2\" >0.008867</td>\n",
660
+ " <td id=\"T_440e0_row1_col3\" class=\"data row1 col3\" >0.009067</td>\n",
661
+ " <td id=\"T_440e0_row1_col4\" class=\"data row1 col4\" >0.008065</td>\n",
662
+ " <td id=\"T_440e0_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
663
+ " </tr>\n",
664
+ " <tr>\n",
665
+ " <th id=\"T_440e0_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
666
+ " <td id=\"T_440e0_row2_col0\" class=\"data row2 col0\" >0.676751</td>\n",
667
+ " <td id=\"T_440e0_row2_col1\" class=\"data row2 col1\" >0.650420</td>\n",
668
+ " <td id=\"T_440e0_row2_col2\" class=\"data row2 col2\" >0.651723</td>\n",
669
+ " <td id=\"T_440e0_row2_col3\" class=\"data row2 col3\" >0.662370</td>\n",
670
+ " <td id=\"T_440e0_row2_col4\" class=\"data row2 col4\" >0.641636</td>\n",
671
+ " <td id=\"T_440e0_row2_col5\" class=\"data row2 col5\" >0.660523</td>\n",
672
+ " </tr>\n",
673
+ " <tr>\n",
674
+ " <th id=\"T_440e0_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
675
+ " <td id=\"T_440e0_row3_col0\" class=\"data row3 col0\" >0.437622</td>\n",
676
+ " <td id=\"T_440e0_row3_col1\" class=\"data row3 col1\" >0.422222</td>\n",
677
+ " <td id=\"T_440e0_row3_col2\" class=\"data row3 col2\" >0.413403</td>\n",
678
+ " <td id=\"T_440e0_row3_col3\" class=\"data row3 col3\" >0.424271</td>\n",
679
+ " <td id=\"T_440e0_row3_col4\" class=\"data row3 col4\" >0.401497</td>\n",
680
+ " <td id=\"T_440e0_row3_col5\" class=\"data row3 col5\" >0.414431</td>\n",
681
+ " </tr>\n",
682
+ " <tr>\n",
683
+ " <th id=\"T_440e0_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
684
+ " <td id=\"T_440e0_row4_col0\" class=\"data row4 col0\" >0.352433</td>\n",
685
+ " <td id=\"T_440e0_row4_col1\" class=\"data row4 col1\" >0.329957</td>\n",
686
+ " <td id=\"T_440e0_row4_col2\" class=\"data row4 col2\" >0.326379</td>\n",
687
+ " <td id=\"T_440e0_row4_col3\" class=\"data row4 col3\" >0.337386</td>\n",
688
+ " <td id=\"T_440e0_row4_col4\" class=\"data row4 col4\" >0.315673</td>\n",
689
+ " <td id=\"T_440e0_row4_col5\" class=\"data row4 col5\" >0.331889</td>\n",
690
+ " </tr>\n",
691
+ " <tr>\n",
692
+ " <th id=\"T_440e0_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
693
+ " <td id=\"T_440e0_row5_col0\" class=\"data row5 col0\" >0.577121</td>\n",
694
+ " <td id=\"T_440e0_row5_col1\" class=\"data row5 col1\" >0.586118</td>\n",
695
+ " <td id=\"T_440e0_row5_col2\" class=\"data row5 col2\" >0.563707</td>\n",
696
+ " <td id=\"T_440e0_row5_col3\" class=\"data row5 col3\" >0.571429</td>\n",
697
+ " <td id=\"T_440e0_row5_col4\" class=\"data row5 col4\" >0.551414</td>\n",
698
+ " <td id=\"T_440e0_row5_col5\" class=\"data row5 col5\" >0.551621</td>\n",
699
+ " </tr>\n",
700
+ " <tr>\n",
701
+ " <th id=\"T_440e0_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
702
+ " <td id=\"T_440e0_row6_col0\" class=\"data row6 col0\" >0.688461</td>\n",
703
+ " <td id=\"T_440e0_row6_col1\" class=\"data row6 col1\" >0.667854</td>\n",
704
+ " <td id=\"T_440e0_row6_col2\" class=\"data row6 col2\" >0.663118</td>\n",
705
+ " <td id=\"T_440e0_row6_col3\" class=\"data row6 col3\" >0.671007</td>\n",
706
+ " <td id=\"T_440e0_row6_col4\" class=\"data row6 col4\" >0.644010</td>\n",
707
+ " <td id=\"T_440e0_row6_col5\" class=\"data row6 col5\" >0.667540</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <th id=\"T_440e0_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
711
+ " <td id=\"T_440e0_row7_col0\" class=\"data row7 col0\" >-0.221720</td>\n",
712
+ " <td id=\"T_440e0_row7_col1\" class=\"data row7 col1\" >-0.228903</td>\n",
713
+ " <td id=\"T_440e0_row7_col2\" class=\"data row7 col2\" >-0.228032</td>\n",
714
+ " <td id=\"T_440e0_row7_col3\" class=\"data row7 col3\" >-0.226401</td>\n",
715
+ " <td id=\"T_440e0_row7_col4\" class=\"data row7 col4\" >-0.231747</td>\n",
716
+ " <td id=\"T_440e0_row7_col5\" class=\"data row7 col5\" >-0.225709</td>\n",
717
+ " </tr>\n",
718
+ " <tr>\n",
719
+ " <th id=\"T_440e0_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
720
+ " <td id=\"T_440e0_row8_col0\" class=\"data row8 col0\" >-0.635799</td>\n",
721
+ " <td id=\"T_440e0_row8_col1\" class=\"data row8 col1\" >-0.651183</td>\n",
722
+ " <td id=\"T_440e0_row8_col2\" class=\"data row8 col2\" >-0.649342</td>\n",
723
+ " <td id=\"T_440e0_row8_col3\" class=\"data row8 col3\" >-0.646086</td>\n",
724
+ " <td id=\"T_440e0_row8_col4\" class=\"data row8 col4\" >-0.657217</td>\n",
725
+ " <td id=\"T_440e0_row8_col5\" class=\"data row8 col5\" >-0.644816</td>\n",
726
+ " </tr>\n",
727
+ " </tbody>\n",
728
+ "</table>\n"
729
+ ],
730
+ "text/plain": [
731
+ "<pandas.io.formats.style.Styler at 0x21e08e87ca0>"
732
+ ]
733
+ },
734
+ "metadata": {},
735
+ "output_type": "display_data"
736
+ },
737
+ {
738
+ "data": {
739
+ "text/html": [
740
+ "<style type=\"text/css\">\n",
741
+ "</style>\n",
742
+ "<table id=\"T_0ee79\">\n",
743
+ " <caption>NB Validation Scores</caption>\n",
744
+ " <thead>\n",
745
+ " <tr>\n",
746
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
747
+ " <th id=\"T_0ee79_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
748
+ " <th id=\"T_0ee79_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
749
+ " <th id=\"T_0ee79_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
750
+ " <th id=\"T_0ee79_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
751
+ " <th id=\"T_0ee79_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
752
+ " <th id=\"T_0ee79_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
753
+ " </tr>\n",
754
+ " </thead>\n",
755
+ " <tbody>\n",
756
+ " <tr>\n",
757
+ " <th id=\"T_0ee79_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
758
+ " <td id=\"T_0ee79_row0_col0\" class=\"data row0 col0\" >0.014398</td>\n",
759
+ " <td id=\"T_0ee79_row0_col1\" class=\"data row0 col1\" >0.009557</td>\n",
760
+ " <td id=\"T_0ee79_row0_col2\" class=\"data row0 col2\" >0.006821</td>\n",
761
+ " <td id=\"T_0ee79_row0_col3\" class=\"data row0 col3\" >0.007424</td>\n",
762
+ " <td id=\"T_0ee79_row0_col4\" class=\"data row0 col4\" >0.012280</td>\n",
763
+ " <td id=\"T_0ee79_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
764
+ " </tr>\n",
765
+ " <tr>\n",
766
+ " <th id=\"T_0ee79_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
767
+ " <td id=\"T_0ee79_row1_col0\" class=\"data row1 col0\" >0.011809</td>\n",
768
+ " <td id=\"T_0ee79_row1_col1\" class=\"data row1 col1\" >0.011134</td>\n",
769
+ " <td id=\"T_0ee79_row1_col2\" class=\"data row1 col2\" >0.017092</td>\n",
770
+ " <td id=\"T_0ee79_row1_col3\" class=\"data row1 col3\" >0.011901</td>\n",
771
+ " <td id=\"T_0ee79_row1_col4\" class=\"data row1 col4\" >0.011843</td>\n",
772
+ " <td id=\"T_0ee79_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
773
+ " </tr>\n",
774
+ " <tr>\n",
775
+ " <th id=\"T_0ee79_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
776
+ " <td id=\"T_0ee79_row2_col0\" class=\"data row2 col0\" >0.713725</td>\n",
777
+ " <td id=\"T_0ee79_row2_col1\" class=\"data row2 col1\" >0.696639</td>\n",
778
+ " <td id=\"T_0ee79_row2_col2\" class=\"data row2 col2\" >0.695433</td>\n",
779
+ " <td id=\"T_0ee79_row2_col3\" class=\"data row2 col3\" >0.695153</td>\n",
780
+ " <td id=\"T_0ee79_row2_col4\" class=\"data row2 col4\" >0.678902</td>\n",
781
+ " <td id=\"T_0ee79_row2_col5\" class=\"data row2 col5\" >0.699608</td>\n",
782
+ " </tr>\n",
783
+ " <tr>\n",
784
+ " <th id=\"T_0ee79_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
785
+ " <td id=\"T_0ee79_row3_col0\" class=\"data row3 col0\" >0.398115</td>\n",
786
+ " <td id=\"T_0ee79_row3_col1\" class=\"data row3 col1\" >0.385706</td>\n",
787
+ " <td id=\"T_0ee79_row3_col2\" class=\"data row3 col2\" >0.386222</td>\n",
788
+ " <td id=\"T_0ee79_row3_col3\" class=\"data row3 col3\" >0.373993</td>\n",
789
+ " <td id=\"T_0ee79_row3_col4\" class=\"data row3 col4\" >0.358343</td>\n",
790
+ " <td id=\"T_0ee79_row3_col5\" class=\"data row3 col5\" >0.381260</td>\n",
791
+ " </tr>\n",
792
+ " <tr>\n",
793
+ " <th id=\"T_0ee79_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
794
+ " <td id=\"T_0ee79_row4_col0\" class=\"data row4 col0\" >0.367391</td>\n",
795
+ " <td id=\"T_0ee79_row4_col1\" class=\"data row4 col1\" >0.345178</td>\n",
796
+ " <td id=\"T_0ee79_row4_col2\" class=\"data row4 col2\" >0.344064</td>\n",
797
+ " <td id=\"T_0ee79_row4_col3\" class=\"data row4 col3\" >0.338189</td>\n",
798
+ " <td id=\"T_0ee79_row4_col4\" class=\"data row4 col4\" >0.317460</td>\n",
799
+ " <td id=\"T_0ee79_row4_col5\" class=\"data row4 col5\" >0.345703</td>\n",
800
+ " </tr>\n",
801
+ " <tr>\n",
802
+ " <th id=\"T_0ee79_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
803
+ " <td id=\"T_0ee79_row5_col0\" class=\"data row5 col0\" >0.434447</td>\n",
804
+ " <td id=\"T_0ee79_row5_col1\" class=\"data row5 col1\" >0.437018</td>\n",
805
+ " <td id=\"T_0ee79_row5_col2\" class=\"data row5 col2\" >0.440154</td>\n",
806
+ " <td id=\"T_0ee79_row5_col3\" class=\"data row5 col3\" >0.418275</td>\n",
807
+ " <td id=\"T_0ee79_row5_col4\" class=\"data row5 col4\" >0.411311</td>\n",
808
+ " <td id=\"T_0ee79_row5_col5\" class=\"data row5 col5\" >0.424970</td>\n",
809
+ " </tr>\n",
810
+ " <tr>\n",
811
+ " <th id=\"T_0ee79_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
812
+ " <td id=\"T_0ee79_row6_col0\" class=\"data row6 col0\" >0.658495</td>\n",
813
+ " <td id=\"T_0ee79_row6_col1\" class=\"data row6 col1\" >0.637091</td>\n",
814
+ " <td id=\"T_0ee79_row6_col2\" class=\"data row6 col2\" >0.633805</td>\n",
815
+ " <td id=\"T_0ee79_row6_col3\" class=\"data row6 col3\" >0.633238</td>\n",
816
+ " <td id=\"T_0ee79_row6_col4\" class=\"data row6 col4\" >0.609990</td>\n",
817
+ " <td id=\"T_0ee79_row6_col5\" class=\"data row6 col5\" >0.634031</td>\n",
818
+ " </tr>\n",
819
+ " <tr>\n",
820
+ " <th id=\"T_0ee79_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
821
+ " <td id=\"T_0ee79_row7_col0\" class=\"data row7 col0\" >-0.251232</td>\n",
822
+ " <td id=\"T_0ee79_row7_col1\" class=\"data row7 col1\" >-0.260890</td>\n",
823
+ " <td id=\"T_0ee79_row7_col2\" class=\"data row7 col2\" >-0.271004</td>\n",
824
+ " <td id=\"T_0ee79_row7_col3\" class=\"data row7 col3\" >-0.273609</td>\n",
825
+ " <td id=\"T_0ee79_row7_col4\" class=\"data row7 col4\" >-0.285054</td>\n",
826
+ " <td id=\"T_0ee79_row7_col5\" class=\"data row7 col5\" >-0.268468</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <th id=\"T_0ee79_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
830
+ " <td id=\"T_0ee79_row8_col0\" class=\"data row8 col0\" >-1.412893</td>\n",
831
+ " <td id=\"T_0ee79_row8_col1\" class=\"data row8 col1\" >-1.627295</td>\n",
832
+ " <td id=\"T_0ee79_row8_col2\" class=\"data row8 col2\" >-1.745289</td>\n",
833
+ " <td id=\"T_0ee79_row8_col3\" class=\"data row8 col3\" >-1.752608</td>\n",
834
+ " <td id=\"T_0ee79_row8_col4\" class=\"data row8 col4\" >-1.950351</td>\n",
835
+ " <td id=\"T_0ee79_row8_col5\" class=\"data row8 col5\" >-1.659029</td>\n",
836
+ " </tr>\n",
837
+ " </tbody>\n",
838
+ "</table>\n"
839
+ ],
840
+ "text/plain": [
841
+ "<pandas.io.formats.style.Styler at 0x21e09deee00>"
842
+ ]
843
+ },
844
+ "metadata": {},
845
+ "output_type": "display_data"
846
+ },
847
+ {
848
+ "data": {
849
+ "text/html": [
850
+ "<style type=\"text/css\">\n",
851
+ "</style>\n",
852
+ "<table id=\"T_4933a\">\n",
853
+ " <caption>KNN Validation Scores</caption>\n",
854
+ " <thead>\n",
855
+ " <tr>\n",
856
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
857
+ " <th id=\"T_4933a_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
858
+ " <th id=\"T_4933a_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
859
+ " <th id=\"T_4933a_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
860
+ " <th id=\"T_4933a_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
861
+ " <th id=\"T_4933a_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
862
+ " <th id=\"T_4933a_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
863
+ " </tr>\n",
864
+ " </thead>\n",
865
+ " <tbody>\n",
866
+ " <tr>\n",
867
+ " <th id=\"T_4933a_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
868
+ " <td id=\"T_4933a_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
869
+ " <td id=\"T_4933a_row0_col1\" class=\"data row0 col1\" >0.002192</td>\n",
870
+ " <td id=\"T_4933a_row0_col2\" class=\"data row0 col2\" >0.004006</td>\n",
871
+ " <td id=\"T_4933a_row0_col3\" class=\"data row0 col3\" >0.004514</td>\n",
872
+ " <td id=\"T_4933a_row0_col4\" class=\"data row0 col4\" >0.006648</td>\n",
873
+ " <td id=\"T_4933a_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
874
+ " </tr>\n",
875
+ " <tr>\n",
876
+ " <th id=\"T_4933a_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
877
+ " <td id=\"T_4933a_row1_col0\" class=\"data row1 col0\" >0.241018</td>\n",
878
+ " <td id=\"T_4933a_row1_col1\" class=\"data row1 col1\" >0.238426</td>\n",
879
+ " <td id=\"T_4933a_row1_col2\" class=\"data row1 col2\" >0.234827</td>\n",
880
+ " <td id=\"T_4933a_row1_col3\" class=\"data row1 col3\" >0.250730</td>\n",
881
+ " <td id=\"T_4933a_row1_col4\" class=\"data row1 col4\" >0.503909</td>\n",
882
+ " <td id=\"T_4933a_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
883
+ " </tr>\n",
884
+ " <tr>\n",
885
+ " <th id=\"T_4933a_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
886
+ " <td id=\"T_4933a_row2_col0\" class=\"data row2 col0\" >0.749860</td>\n",
887
+ " <td id=\"T_4933a_row2_col1\" class=\"data row2 col1\" >0.757703</td>\n",
888
+ " <td id=\"T_4933a_row2_col2\" class=\"data row2 col2\" >0.750911</td>\n",
889
+ " <td id=\"T_4933a_row2_col3\" class=\"data row2 col3\" >0.753713</td>\n",
890
+ " <td id=\"T_4933a_row2_col4\" class=\"data row2 col4\" >0.746708</td>\n",
891
+ " <td id=\"T_4933a_row2_col5\" class=\"data row2 col5\" >0.754379</td>\n",
892
+ " </tr>\n",
893
+ " <tr>\n",
894
+ " <th id=\"T_4933a_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
895
+ " <td id=\"T_4933a_row3_col0\" class=\"data row3 col0\" >0.203390</td>\n",
896
+ " <td id=\"T_4933a_row3_col1\" class=\"data row3 col1\" >0.214351</td>\n",
897
+ " <td id=\"T_4933a_row3_col2\" class=\"data row3 col2\" >0.201258</td>\n",
898
+ " <td id=\"T_4933a_row3_col3\" class=\"data row3 col3\" >0.208821</td>\n",
899
+ " <td id=\"T_4933a_row3_col4\" class=\"data row3 col4\" >0.208406</td>\n",
900
+ " <td id=\"T_4933a_row3_col5\" class=\"data row3 col5\" >0.214136</td>\n",
901
+ " </tr>\n",
902
+ " <tr>\n",
903
+ " <th id=\"T_4933a_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
904
+ " <td id=\"T_4933a_row4_col0\" class=\"data row4 col0\" >0.332362</td>\n",
905
+ " <td id=\"T_4933a_row4_col1\" class=\"data row4 col1\" >0.365325</td>\n",
906
+ " <td id=\"T_4933a_row4_col2\" class=\"data row4 col2\" >0.333333</td>\n",
907
+ " <td id=\"T_4933a_row4_col3\" class=\"data row4 col3\" >0.347305</td>\n",
908
+ " <td id=\"T_4933a_row4_col4\" class=\"data row4 col4\" >0.326923</td>\n",
909
+ " <td id=\"T_4933a_row4_col5\" class=\"data row4 col5\" >0.353103</td>\n",
910
+ " </tr>\n",
911
+ " <tr>\n",
912
+ " <th id=\"T_4933a_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
913
+ " <td id=\"T_4933a_row5_col0\" class=\"data row5 col0\" >0.146530</td>\n",
914
+ " <td id=\"T_4933a_row5_col1\" class=\"data row5 col1\" >0.151671</td>\n",
915
+ " <td id=\"T_4933a_row5_col2\" class=\"data row5 col2\" >0.144144</td>\n",
916
+ " <td id=\"T_4933a_row5_col3\" class=\"data row5 col3\" >0.149292</td>\n",
917
+ " <td id=\"T_4933a_row5_col4\" class=\"data row5 col4\" >0.152956</td>\n",
918
+ " <td id=\"T_4933a_row5_col5\" class=\"data row5 col5\" >0.153661</td>\n",
919
+ " </tr>\n",
920
+ " <tr>\n",
921
+ " <th id=\"T_4933a_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
922
+ " <td id=\"T_4933a_row6_col0\" class=\"data row6 col0\" >0.583018</td>\n",
923
+ " <td id=\"T_4933a_row6_col1\" class=\"data row6 col1\" >0.588834</td>\n",
924
+ " <td id=\"T_4933a_row6_col2\" class=\"data row6 col2\" >0.575426</td>\n",
925
+ " <td id=\"T_4933a_row6_col3\" class=\"data row6 col3\" >0.575821</td>\n",
926
+ " <td id=\"T_4933a_row6_col4\" class=\"data row6 col4\" >0.573316</td>\n",
927
+ " <td id=\"T_4933a_row6_col5\" class=\"data row6 col5\" >0.582555</td>\n",
928
+ " </tr>\n",
929
+ " <tr>\n",
930
+ " <th id=\"T_4933a_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
931
+ " <td id=\"T_4933a_row7_col0\" class=\"data row7 col0\" >-0.190980</td>\n",
932
+ " <td id=\"T_4933a_row7_col1\" class=\"data row7 col1\" >-0.186599</td>\n",
933
+ " <td id=\"T_4933a_row7_col2\" class=\"data row7 col2\" >-0.191325</td>\n",
934
+ " <td id=\"T_4933a_row7_col3\" class=\"data row7 col3\" >-0.190608</td>\n",
935
+ " <td id=\"T_4933a_row7_col4\" class=\"data row7 col4\" >-0.193769</td>\n",
936
+ " <td id=\"T_4933a_row7_col5\" class=\"data row7 col5\" >-0.189077</td>\n",
937
+ " </tr>\n",
938
+ " <tr>\n",
939
+ " <th id=\"T_4933a_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
940
+ " <td id=\"T_4933a_row8_col0\" class=\"data row8 col0\" >-2.500209</td>\n",
941
+ " <td id=\"T_4933a_row8_col1\" class=\"data row8 col1\" >-2.097705</td>\n",
942
+ " <td id=\"T_4933a_row8_col2\" class=\"data row8 col2\" >-2.376728</td>\n",
943
+ " <td id=\"T_4933a_row8_col3\" class=\"data row8 col3\" >-2.528142</td>\n",
944
+ " <td id=\"T_4933a_row8_col4\" class=\"data row8 col4\" >-2.570247</td>\n",
945
+ " <td id=\"T_4933a_row8_col5\" class=\"data row8 col5\" >-2.428060</td>\n",
946
+ " </tr>\n",
947
+ " </tbody>\n",
948
+ "</table>\n"
949
+ ],
950
+ "text/plain": [
951
+ "<pandas.io.formats.style.Styler at 0x21e093db520>"
952
+ ]
953
+ },
954
+ "metadata": {},
955
+ "output_type": "display_data"
956
+ },
957
+ {
958
+ "data": {
959
+ "text/html": [
960
+ "<style type=\"text/css\">\n",
961
+ "</style>\n",
962
+ "<table id=\"T_a0470\">\n",
963
+ " <caption>RF Validation Scores</caption>\n",
964
+ " <thead>\n",
965
+ " <tr>\n",
966
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
967
+ " <th id=\"T_a0470_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
968
+ " <th id=\"T_a0470_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
969
+ " <th id=\"T_a0470_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
970
+ " <th id=\"T_a0470_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
971
+ " <th id=\"T_a0470_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
972
+ " <th id=\"T_a0470_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
973
+ " </tr>\n",
974
+ " </thead>\n",
975
+ " <tbody>\n",
976
+ " <tr>\n",
977
+ " <th id=\"T_a0470_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
978
+ " <td id=\"T_a0470_row0_col0\" class=\"data row0 col0\" >1.522204</td>\n",
979
+ " <td id=\"T_a0470_row0_col1\" class=\"data row0 col1\" >1.489951</td>\n",
980
+ " <td id=\"T_a0470_row0_col2\" class=\"data row0 col2\" >1.485231</td>\n",
981
+ " <td id=\"T_a0470_row0_col3\" class=\"data row0 col3\" >1.517533</td>\n",
982
+ " <td id=\"T_a0470_row0_col4\" class=\"data row0 col4\" >1.491071</td>\n",
983
+ " <td id=\"T_a0470_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
984
+ " </tr>\n",
985
+ " <tr>\n",
986
+ " <th id=\"T_a0470_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
987
+ " <td id=\"T_a0470_row1_col0\" class=\"data row1 col0\" >0.152138</td>\n",
988
+ " <td id=\"T_a0470_row1_col1\" class=\"data row1 col1\" >0.145954</td>\n",
989
+ " <td id=\"T_a0470_row1_col2\" class=\"data row1 col2\" >0.156016</td>\n",
990
+ " <td id=\"T_a0470_row1_col3\" class=\"data row1 col3\" >0.150085</td>\n",
991
+ " <td id=\"T_a0470_row1_col4\" class=\"data row1 col4\" >0.141980</td>\n",
992
+ " <td id=\"T_a0470_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
993
+ " </tr>\n",
994
+ " <tr>\n",
995
+ " <th id=\"T_a0470_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
996
+ " <td id=\"T_a0470_row2_col0\" class=\"data row2 col0\" >0.738655</td>\n",
997
+ " <td id=\"T_a0470_row2_col1\" class=\"data row2 col1\" >0.746779</td>\n",
998
+ " <td id=\"T_a0470_row2_col2\" class=\"data row2 col2\" >0.740544</td>\n",
999
+ " <td id=\"T_a0470_row2_col3\" class=\"data row2 col3\" >0.734099</td>\n",
1000
+ " <td id=\"T_a0470_row2_col4\" class=\"data row2 col4\" >0.737461</td>\n",
1001
+ " <td id=\"T_a0470_row2_col5\" class=\"data row2 col5\" >0.733595</td>\n",
1002
+ " </tr>\n",
1003
+ " <tr>\n",
1004
+ " <th id=\"T_a0470_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1005
+ " <td id=\"T_a0470_row3_col0\" class=\"data row3 col0\" >0.240846</td>\n",
1006
+ " <td id=\"T_a0470_row3_col1\" class=\"data row3 col1\" >0.262643</td>\n",
1007
+ " <td id=\"T_a0470_row3_col2\" class=\"data row3 col2\" >0.227045</td>\n",
1008
+ " <td id=\"T_a0470_row3_col3\" class=\"data row3 col3\" >0.203191</td>\n",
1009
+ " <td id=\"T_a0470_row3_col4\" class=\"data row3 col4\" >0.244964</td>\n",
1010
+ " <td id=\"T_a0470_row3_col5\" class=\"data row3 col5\" >0.242379</td>\n",
1011
+ " </tr>\n",
1012
+ " <tr>\n",
1013
+ " <th id=\"T_a0470_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1014
+ " <td id=\"T_a0470_row4_col0\" class=\"data row4 col0\" >0.328160</td>\n",
1015
+ " <td id=\"T_a0470_row4_col1\" class=\"data row4 col1\" >0.359375</td>\n",
1016
+ " <td id=\"T_a0470_row4_col2\" class=\"data row4 col2\" >0.323040</td>\n",
1017
+ " <td id=\"T_a0470_row4_col3\" class=\"data row4 col3\" >0.292271</td>\n",
1018
+ " <td id=\"T_a0470_row4_col4\" class=\"data row4 col4\" >0.328294</td>\n",
1019
+ " <td id=\"T_a0470_row4_col5\" class=\"data row4 col5\" >0.318359</td>\n",
1020
+ " </tr>\n",
1021
+ " <tr>\n",
1022
+ " <th id=\"T_a0470_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1023
+ " <td id=\"T_a0470_row5_col0\" class=\"data row5 col0\" >0.190231</td>\n",
1024
+ " <td id=\"T_a0470_row5_col1\" class=\"data row5 col1\" >0.206941</td>\n",
1025
+ " <td id=\"T_a0470_row5_col2\" class=\"data row5 col2\" >0.175032</td>\n",
1026
+ " <td id=\"T_a0470_row5_col3\" class=\"data row5 col3\" >0.155727</td>\n",
1027
+ " <td id=\"T_a0470_row5_col4\" class=\"data row5 col4\" >0.195373</td>\n",
1028
+ " <td id=\"T_a0470_row5_col5\" class=\"data row5 col5\" >0.195678</td>\n",
1029
+ " </tr>\n",
1030
+ " <tr>\n",
1031
+ " <th id=\"T_a0470_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1032
+ " <td id=\"T_a0470_row6_col0\" class=\"data row6 col0\" >0.616397</td>\n",
1033
+ " <td id=\"T_a0470_row6_col1\" class=\"data row6 col1\" >0.614756</td>\n",
1034
+ " <td id=\"T_a0470_row6_col2\" class=\"data row6 col2\" >0.597688</td>\n",
1035
+ " <td id=\"T_a0470_row6_col3\" class=\"data row6 col3\" >0.602290</td>\n",
1036
+ " <td id=\"T_a0470_row6_col4\" class=\"data row6 col4\" >0.596575</td>\n",
1037
+ " <td id=\"T_a0470_row6_col5\" class=\"data row6 col5\" >0.601436</td>\n",
1038
+ " </tr>\n",
1039
+ " <tr>\n",
1040
+ " <th id=\"T_a0470_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1041
+ " <td id=\"T_a0470_row7_col0\" class=\"data row7 col0\" >-0.184849</td>\n",
1042
+ " <td id=\"T_a0470_row7_col1\" class=\"data row7 col1\" >-0.182776</td>\n",
1043
+ " <td id=\"T_a0470_row7_col2\" class=\"data row7 col2\" >-0.186704</td>\n",
1044
+ " <td id=\"T_a0470_row7_col3\" class=\"data row7 col3\" >-0.187369</td>\n",
1045
+ " <td id=\"T_a0470_row7_col4\" class=\"data row7 col4\" >-0.189693</td>\n",
1046
+ " <td id=\"T_a0470_row7_col5\" class=\"data row7 col5\" >-0.188506</td>\n",
1047
+ " </tr>\n",
1048
+ " <tr>\n",
1049
+ " <th id=\"T_a0470_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1050
+ " <td id=\"T_a0470_row8_col0\" class=\"data row8 col0\" >-0.711816</td>\n",
1051
+ " <td id=\"T_a0470_row8_col1\" class=\"data row8 col1\" >-0.673727</td>\n",
1052
+ " <td id=\"T_a0470_row8_col2\" class=\"data row8 col2\" >-0.766073</td>\n",
1053
+ " <td id=\"T_a0470_row8_col3\" class=\"data row8 col3\" >-0.719537</td>\n",
1054
+ " <td id=\"T_a0470_row8_col4\" class=\"data row8 col4\" >-0.775028</td>\n",
1055
+ " <td id=\"T_a0470_row8_col5\" class=\"data row8 col5\" >-0.743245</td>\n",
1056
+ " </tr>\n",
1057
+ " </tbody>\n",
1058
+ "</table>\n"
1059
+ ],
1060
+ "text/plain": [
1061
+ "<pandas.io.formats.style.Styler at 0x21e6237c3a0>"
1062
+ ]
1063
+ },
1064
+ "metadata": {},
1065
+ "output_type": "display_data"
1066
+ },
1067
+ {
1068
+ "data": {
1069
+ "text/html": [
1070
+ "<style type=\"text/css\">\n",
1071
+ "</style>\n",
1072
+ "<table id=\"T_87742\">\n",
1073
+ " <caption>GBC Validation Scores</caption>\n",
1074
+ " <thead>\n",
1075
+ " <tr>\n",
1076
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1077
+ " <th id=\"T_87742_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1078
+ " <th id=\"T_87742_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1079
+ " <th id=\"T_87742_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1080
+ " <th id=\"T_87742_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1081
+ " <th id=\"T_87742_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1082
+ " <th id=\"T_87742_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1083
+ " </tr>\n",
1084
+ " </thead>\n",
1085
+ " <tbody>\n",
1086
+ " <tr>\n",
1087
+ " <th id=\"T_87742_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1088
+ " <td id=\"T_87742_row0_col0\" class=\"data row0 col0\" >1.292703</td>\n",
1089
+ " <td id=\"T_87742_row0_col1\" class=\"data row0 col1\" >1.329037</td>\n",
1090
+ " <td id=\"T_87742_row0_col2\" class=\"data row0 col2\" >1.315473</td>\n",
1091
+ " <td id=\"T_87742_row0_col3\" class=\"data row0 col3\" >1.299452</td>\n",
1092
+ " <td id=\"T_87742_row0_col4\" class=\"data row0 col4\" >1.313950</td>\n",
1093
+ " <td id=\"T_87742_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1094
+ " </tr>\n",
1095
+ " <tr>\n",
1096
+ " <th id=\"T_87742_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1097
+ " <td id=\"T_87742_row1_col0\" class=\"data row1 col0\" >0.014560</td>\n",
1098
+ " <td id=\"T_87742_row1_col1\" class=\"data row1 col1\" >0.026127</td>\n",
1099
+ " <td id=\"T_87742_row1_col2\" class=\"data row1 col2\" >0.018812</td>\n",
1100
+ " <td id=\"T_87742_row1_col3\" class=\"data row1 col3\" >0.015999</td>\n",
1101
+ " <td id=\"T_87742_row1_col4\" class=\"data row1 col4\" >0.020753</td>\n",
1102
+ " <td id=\"T_87742_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1103
+ " </tr>\n",
1104
+ " <tr>\n",
1105
+ " <th id=\"T_87742_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1106
+ " <td id=\"T_87742_row2_col0\" class=\"data row2 col0\" >0.783193</td>\n",
1107
+ " <td id=\"T_87742_row2_col1\" class=\"data row2 col1\" >0.777871</td>\n",
1108
+ " <td id=\"T_87742_row2_col2\" class=\"data row2 col2\" >0.782292</td>\n",
1109
+ " <td id=\"T_87742_row2_col3\" class=\"data row2 col3\" >0.779770</td>\n",
1110
+ " <td id=\"T_87742_row2_col4\" class=\"data row2 col4\" >0.780050</td>\n",
1111
+ " <td id=\"T_87742_row2_col5\" class=\"data row2 col5\" >0.781830</td>\n",
1112
+ " </tr>\n",
1113
+ " <tr>\n",
1114
+ " <th id=\"T_87742_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1115
+ " <td id=\"T_87742_row3_col0\" class=\"data row3 col0\" >0.112385</td>\n",
1116
+ " <td id=\"T_87742_row3_col1\" class=\"data row3 col1\" >0.072515</td>\n",
1117
+ " <td id=\"T_87742_row3_col2\" class=\"data row3 col2\" >0.091228</td>\n",
1118
+ " <td id=\"T_87742_row3_col3\" class=\"data row3 col3\" >0.075294</td>\n",
1119
+ " <td id=\"T_87742_row3_col4\" class=\"data row3 col4\" >0.081871</td>\n",
1120
+ " <td id=\"T_87742_row3_col5\" class=\"data row3 col5\" >0.085479</td>\n",
1121
+ " </tr>\n",
1122
+ " <tr>\n",
1123
+ " <th id=\"T_87742_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1124
+ " <td id=\"T_87742_row4_col0\" class=\"data row4 col0\" >0.521277</td>\n",
1125
+ " <td id=\"T_87742_row4_col1\" class=\"data row4 col1\" >0.402597</td>\n",
1126
+ " <td id=\"T_87742_row4_col2\" class=\"data row4 col2\" >0.500000</td>\n",
1127
+ " <td id=\"T_87742_row4_col3\" class=\"data row4 col3\" >0.438356</td>\n",
1128
+ " <td id=\"T_87742_row4_col4\" class=\"data row4 col4\" >0.454545</td>\n",
1129
+ " <td id=\"T_87742_row4_col5\" class=\"data row4 col5\" >0.490566</td>\n",
1130
+ " </tr>\n",
1131
+ " <tr>\n",
1132
+ " <th id=\"T_87742_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1133
+ " <td id=\"T_87742_row5_col0\" class=\"data row5 col0\" >0.062982</td>\n",
1134
+ " <td id=\"T_87742_row5_col1\" class=\"data row5 col1\" >0.039846</td>\n",
1135
+ " <td id=\"T_87742_row5_col2\" class=\"data row5 col2\" >0.050193</td>\n",
1136
+ " <td id=\"T_87742_row5_col3\" class=\"data row5 col3\" >0.041184</td>\n",
1137
+ " <td id=\"T_87742_row5_col4\" class=\"data row5 col4\" >0.044987</td>\n",
1138
+ " <td id=\"T_87742_row5_col5\" class=\"data row5 col5\" >0.046819</td>\n",
1139
+ " </tr>\n",
1140
+ " <tr>\n",
1141
+ " <th id=\"T_87742_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1142
+ " <td id=\"T_87742_row6_col0\" class=\"data row6 col0\" >0.684278</td>\n",
1143
+ " <td id=\"T_87742_row6_col1\" class=\"data row6 col1\" >0.665449</td>\n",
1144
+ " <td id=\"T_87742_row6_col2\" class=\"data row6 col2\" >0.670804</td>\n",
1145
+ " <td id=\"T_87742_row6_col3\" class=\"data row6 col3\" >0.671522</td>\n",
1146
+ " <td id=\"T_87742_row6_col4\" class=\"data row6 col4\" >0.654303</td>\n",
1147
+ " <td id=\"T_87742_row6_col5\" class=\"data row6 col5\" >0.669660</td>\n",
1148
+ " </tr>\n",
1149
+ " <tr>\n",
1150
+ " <th id=\"T_87742_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1151
+ " <td id=\"T_87742_row7_col0\" class=\"data row7 col0\" >-0.157619</td>\n",
1152
+ " <td id=\"T_87742_row7_col1\" class=\"data row7 col1\" >-0.160471</td>\n",
1153
+ " <td id=\"T_87742_row7_col2\" class=\"data row7 col2\" >-0.159413</td>\n",
1154
+ " <td id=\"T_87742_row7_col3\" class=\"data row7 col3\" >-0.159018</td>\n",
1155
+ " <td id=\"T_87742_row7_col4\" class=\"data row7 col4\" >-0.161594</td>\n",
1156
+ " <td id=\"T_87742_row7_col5\" class=\"data row7 col5\" >-0.159691</td>\n",
1157
+ " </tr>\n",
1158
+ " <tr>\n",
1159
+ " <th id=\"T_87742_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1160
+ " <td id=\"T_87742_row8_col0\" class=\"data row8 col0\" >-0.488690</td>\n",
1161
+ " <td id=\"T_87742_row8_col1\" class=\"data row8 col1\" >-0.495751</td>\n",
1162
+ " <td id=\"T_87742_row8_col2\" class=\"data row8 col2\" >-0.493124</td>\n",
1163
+ " <td id=\"T_87742_row8_col3\" class=\"data row8 col3\" >-0.492051</td>\n",
1164
+ " <td id=\"T_87742_row8_col4\" class=\"data row8 col4\" >-0.498777</td>\n",
1165
+ " <td id=\"T_87742_row8_col5\" class=\"data row8 col5\" >-0.493648</td>\n",
1166
+ " </tr>\n",
1167
+ " </tbody>\n",
1168
+ "</table>\n"
1169
+ ],
1170
+ "text/plain": [
1171
+ "<pandas.io.formats.style.Styler at 0x21e0ab71420>"
1172
+ ]
1173
+ },
1174
+ "metadata": {},
1175
+ "output_type": "display_data"
1176
+ },
1177
+ {
1178
+ "data": {
1179
+ "text/html": [
1180
+ "<style type=\"text/css\">\n",
1181
+ "</style>\n",
1182
+ "<table id=\"T_0e44f\">\n",
1183
+ " <caption>XGB Validation Scores</caption>\n",
1184
+ " <thead>\n",
1185
+ " <tr>\n",
1186
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1187
+ " <th id=\"T_0e44f_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1188
+ " <th id=\"T_0e44f_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1189
+ " <th id=\"T_0e44f_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1190
+ " <th id=\"T_0e44f_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1191
+ " <th id=\"T_0e44f_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1192
+ " <th id=\"T_0e44f_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1193
+ " </tr>\n",
1194
+ " </thead>\n",
1195
+ " <tbody>\n",
1196
+ " <tr>\n",
1197
+ " <th id=\"T_0e44f_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1198
+ " <td id=\"T_0e44f_row0_col0\" class=\"data row0 col0\" >0.643858</td>\n",
1199
+ " <td id=\"T_0e44f_row0_col1\" class=\"data row0 col1\" >0.640153</td>\n",
1200
+ " <td id=\"T_0e44f_row0_col2\" class=\"data row0 col2\" >0.677121</td>\n",
1201
+ " <td id=\"T_0e44f_row0_col3\" class=\"data row0 col3\" >0.634137</td>\n",
1202
+ " <td id=\"T_0e44f_row0_col4\" class=\"data row0 col4\" >0.669338</td>\n",
1203
+ " <td id=\"T_0e44f_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1204
+ " </tr>\n",
1205
+ " <tr>\n",
1206
+ " <th id=\"T_0e44f_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1207
+ " <td id=\"T_0e44f_row1_col0\" class=\"data row1 col0\" >0.023605</td>\n",
1208
+ " <td id=\"T_0e44f_row1_col1\" class=\"data row1 col1\" >0.016412</td>\n",
1209
+ " <td id=\"T_0e44f_row1_col2\" class=\"data row1 col2\" >0.020805</td>\n",
1210
+ " <td id=\"T_0e44f_row1_col3\" class=\"data row1 col3\" >0.015040</td>\n",
1211
+ " <td id=\"T_0e44f_row1_col4\" class=\"data row1 col4\" >0.028892</td>\n",
1212
+ " <td id=\"T_0e44f_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1213
+ " </tr>\n",
1214
+ " <tr>\n",
1215
+ " <th id=\"T_0e44f_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1216
+ " <td id=\"T_0e44f_row2_col0\" class=\"data row2 col0\" >0.620168</td>\n",
1217
+ " <td id=\"T_0e44f_row2_col1\" class=\"data row2 col1\" >0.612605</td>\n",
1218
+ " <td id=\"T_0e44f_row2_col2\" class=\"data row2 col2\" >0.604371</td>\n",
1219
+ " <td id=\"T_0e44f_row2_col3\" class=\"data row2 col3\" >0.628748</td>\n",
1220
+ " <td id=\"T_0e44f_row2_col4\" class=\"data row2 col4\" >0.620062</td>\n",
1221
+ " <td id=\"T_0e44f_row2_col5\" class=\"data row2 col5\" >0.619216</td>\n",
1222
+ " </tr>\n",
1223
+ " <tr>\n",
1224
+ " <th id=\"T_0e44f_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1225
+ " <td id=\"T_0e44f_row3_col0\" class=\"data row3 col0\" >0.397869</td>\n",
1226
+ " <td id=\"T_0e44f_row3_col1\" class=\"data row3 col1\" >0.386696</td>\n",
1227
+ " <td id=\"T_0e44f_row3_col2\" class=\"data row3 col2\" >0.377974</td>\n",
1228
+ " <td id=\"T_0e44f_row3_col3\" class=\"data row3 col3\" >0.386290</td>\n",
1229
+ " <td id=\"T_0e44f_row3_col4\" class=\"data row3 col4\" >0.383076</td>\n",
1230
+ " <td id=\"T_0e44f_row3_col5\" class=\"data row3 col5\" >0.387639</td>\n",
1231
+ " </tr>\n",
1232
+ " <tr>\n",
1233
+ " <th id=\"T_0e44f_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1234
+ " <td id=\"T_0e44f_row4_col0\" class=\"data row4 col0\" >0.303935</td>\n",
1235
+ " <td id=\"T_0e44f_row4_col1\" class=\"data row4 col1\" >0.295193</td>\n",
1236
+ " <td id=\"T_0e44f_row4_col2\" class=\"data row4 col2\" >0.287341</td>\n",
1237
+ " <td id=\"T_0e44f_row4_col3\" class=\"data row4 col3\" >0.301737</td>\n",
1238
+ " <td id=\"T_0e44f_row4_col4\" class=\"data row4 col4\" >0.296479</td>\n",
1239
+ " <td id=\"T_0e44f_row4_col5\" class=\"data row4 col5\" >0.298285</td>\n",
1240
+ " </tr>\n",
1241
+ " <tr>\n",
1242
+ " <th id=\"T_0e44f_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1243
+ " <td id=\"T_0e44f_row5_col0\" class=\"data row5 col0\" >0.575835</td>\n",
1244
+ " <td id=\"T_0e44f_row5_col1\" class=\"data row5 col1\" >0.560411</td>\n",
1245
+ " <td id=\"T_0e44f_row5_col2\" class=\"data row5 col2\" >0.552124</td>\n",
1246
+ " <td id=\"T_0e44f_row5_col3\" class=\"data row5 col3\" >0.536680</td>\n",
1247
+ " <td id=\"T_0e44f_row5_col4\" class=\"data row5 col4\" >0.541131</td>\n",
1248
+ " <td id=\"T_0e44f_row5_col5\" class=\"data row5 col5\" >0.553421</td>\n",
1249
+ " </tr>\n",
1250
+ " <tr>\n",
1251
+ " <th id=\"T_0e44f_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1252
+ " <td id=\"T_0e44f_row6_col0\" class=\"data row6 col0\" >0.640317</td>\n",
1253
+ " <td id=\"T_0e44f_row6_col1\" class=\"data row6 col1\" >0.626210</td>\n",
1254
+ " <td id=\"T_0e44f_row6_col2\" class=\"data row6 col2\" >0.619924</td>\n",
1255
+ " <td id=\"T_0e44f_row6_col3\" class=\"data row6 col3\" >0.627078</td>\n",
1256
+ " <td id=\"T_0e44f_row6_col4\" class=\"data row6 col4\" >0.620747</td>\n",
1257
+ " <td id=\"T_0e44f_row6_col5\" class=\"data row6 col5\" >0.630028</td>\n",
1258
+ " </tr>\n",
1259
+ " <tr>\n",
1260
+ " <th id=\"T_0e44f_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1261
+ " <td id=\"T_0e44f_row7_col0\" class=\"data row7 col0\" >-0.231951</td>\n",
1262
+ " <td id=\"T_0e44f_row7_col1\" class=\"data row7 col1\" >-0.238560</td>\n",
1263
+ " <td id=\"T_0e44f_row7_col2\" class=\"data row7 col2\" >-0.237950</td>\n",
1264
+ " <td id=\"T_0e44f_row7_col3\" class=\"data row7 col3\" >-0.233297</td>\n",
1265
+ " <td id=\"T_0e44f_row7_col4\" class=\"data row7 col4\" >-0.239657</td>\n",
1266
+ " <td id=\"T_0e44f_row7_col5\" class=\"data row7 col5\" >-0.236789</td>\n",
1267
+ " </tr>\n",
1268
+ " <tr>\n",
1269
+ " <th id=\"T_0e44f_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1270
+ " <td id=\"T_0e44f_row8_col0\" class=\"data row8 col0\" >-0.660186</td>\n",
1271
+ " <td id=\"T_0e44f_row8_col1\" class=\"data row8 col1\" >-0.678450</td>\n",
1272
+ " <td id=\"T_0e44f_row8_col2\" class=\"data row8 col2\" >-0.673157</td>\n",
1273
+ " <td id=\"T_0e44f_row8_col3\" class=\"data row8 col3\" >-0.666774</td>\n",
1274
+ " <td id=\"T_0e44f_row8_col4\" class=\"data row8 col4\" >-0.681109</td>\n",
1275
+ " <td id=\"T_0e44f_row8_col5\" class=\"data row8 col5\" >-0.671979</td>\n",
1276
+ " </tr>\n",
1277
+ " </tbody>\n",
1278
+ "</table>\n"
1279
+ ],
1280
+ "text/plain": [
1281
+ "<pandas.io.formats.style.Styler at 0x21e0aab11e0>"
1282
+ ]
1283
+ },
1284
+ "metadata": {},
1285
+ "output_type": "display_data"
1286
+ }
1287
+ ],
1288
+ "source": [
1289
+ "# XGB hyperparameter that deals with unbalanced\n",
1290
+ "scale_pos_weight = Y.mean()**-1\n",
1291
+ "\n",
1292
+ "# Creating the model objects\n",
1293
+ "cls_lr = LogisticRegression(\n",
1294
+ " class_weight=\"balanced\", # Hyperparameter to deal with unbalanced output\n",
1295
+ " random_state=lucky_num)\n",
1296
+ "# cls_svm = SVC(random_state=lucky_num) # Remove due its resource consumption and worst results\n",
1297
+ "cls_nb = GaussianNB()\n",
1298
+ "cls_knn = KNeighborsClassifier()\n",
1299
+ "cls_rf = RandomForestClassifier(\n",
1300
+ " random_state=lucky_num,\n",
1301
+ " class_weight=\"balanced_subsample\") # Hyperparameter to deal with unbalanced output\n",
1302
+ "cls_gbc = GradientBoostingClassifier(random_state=lucky_num)\n",
1303
+ "cls_xgb = xgb.XGBClassifier(\n",
1304
+ " objective=\"binary:logistic\",\n",
1305
+ " verbose=None,\n",
1306
+ " random_state=lucky_num,\n",
1307
+ " scale_pos_weight = scale_pos_weight)\n",
1308
+ "\n",
1309
+ "# Lists to iterate on our modeling function\n",
1310
+ "cls_name = [\"LR\", \"NB\", \"KNN\", \"RF\", \"GBC\", \"XGB\"]\n",
1311
+ "cls_list = [cls_lr, cls_NB, cls_knn, cls_rf, cls_gbc, cls_xgb]\n",
1312
+ "\n",
1313
+ "mdl_summaries = []\n",
1314
+ "for name, inst in zip(cls_name, cls_list):\n",
1315
+ " mdl_list = create_model(name, inst)\n",
1316
+ " mdl_list = [name] + mdl_list\n",
1317
+ " mdl_summaries.append(mdl_list)\n",
1318
+ "\n",
1319
+ "df_mdl = pd.DataFrame(\n",
1320
+ " mdl_summaries,\n",
1321
+ " columns=[\n",
1322
+ " \"model\",\n",
1323
+ " \"test_accuracy\",\n",
1324
+ " \"test_f1\",\n",
1325
+ " \"test_precision\",\n",
1326
+ " \"test_recall\",\n",
1327
+ " \"test_roc_auc\",\n",
1328
+ " \"test_brier\",\n",
1329
+ " \"test_log_loss\"])"
1330
+ ]
1331
+ },
1332
+ {
1333
+ "cell_type": "code",
1334
+ "execution_count": 16,
1335
+ "metadata": {},
1336
+ "outputs": [
1337
+ {
1338
+ "data": {
1339
+ "text/html": [
1340
+ "<style type=\"text/css\">\n",
1341
+ "</style>\n",
1342
+ "<table id=\"T_0de30\">\n",
1343
+ " <caption>Test set validation scores</caption>\n",
1344
+ " <thead>\n",
1345
+ " <tr>\n",
1346
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1347
+ " <th id=\"T_0de30_level0_col0\" class=\"col_heading level0 col0\" >model</th>\n",
1348
+ " <th id=\"T_0de30_level0_col1\" class=\"col_heading level0 col1\" >test_accuracy</th>\n",
1349
+ " <th id=\"T_0de30_level0_col2\" class=\"col_heading level0 col2\" >test_f1</th>\n",
1350
+ " <th id=\"T_0de30_level0_col3\" class=\"col_heading level0 col3\" >test_precision</th>\n",
1351
+ " <th id=\"T_0de30_level0_col4\" class=\"col_heading level0 col4\" >test_recall</th>\n",
1352
+ " <th id=\"T_0de30_level0_col5\" class=\"col_heading level0 col5\" >test_roc_auc</th>\n",
1353
+ " <th id=\"T_0de30_level0_col6\" class=\"col_heading level0 col6\" >test_brier</th>\n",
1354
+ " <th id=\"T_0de30_level0_col7\" class=\"col_heading level0 col7\" >test_log_loss</th>\n",
1355
+ " </tr>\n",
1356
+ " </thead>\n",
1357
+ " <tbody>\n",
1358
+ " <tr>\n",
1359
+ " <th id=\"T_0de30_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
1360
+ " <td id=\"T_0de30_row0_col0\" class=\"data row0 col0\" >LR</td>\n",
1361
+ " <td id=\"T_0de30_row0_col1\" class=\"data row0 col1\" >0.660523</td>\n",
1362
+ " <td id=\"T_0de30_row0_col2\" class=\"data row0 col2\" >0.414431</td>\n",
1363
+ " <td id=\"T_0de30_row0_col3\" class=\"data row0 col3\" >0.331889</td>\n",
1364
+ " <td id=\"T_0de30_row0_col4\" class=\"data row0 col4\" >0.551621</td>\n",
1365
+ " <td id=\"T_0de30_row0_col5\" class=\"data row0 col5\" >0.667540</td>\n",
1366
+ " <td id=\"T_0de30_row0_col6\" class=\"data row0 col6\" >0.225709</td>\n",
1367
+ " <td id=\"T_0de30_row0_col7\" class=\"data row0 col7\" >0.644816</td>\n",
1368
+ " </tr>\n",
1369
+ " <tr>\n",
1370
+ " <th id=\"T_0de30_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
1371
+ " <td id=\"T_0de30_row1_col0\" class=\"data row1 col0\" >XGB</td>\n",
1372
+ " <td id=\"T_0de30_row1_col1\" class=\"data row1 col1\" >0.619216</td>\n",
1373
+ " <td id=\"T_0de30_row1_col2\" class=\"data row1 col2\" >0.387639</td>\n",
1374
+ " <td id=\"T_0de30_row1_col3\" class=\"data row1 col3\" >0.298285</td>\n",
1375
+ " <td id=\"T_0de30_row1_col4\" class=\"data row1 col4\" >0.553421</td>\n",
1376
+ " <td id=\"T_0de30_row1_col5\" class=\"data row1 col5\" >0.630028</td>\n",
1377
+ " <td id=\"T_0de30_row1_col6\" class=\"data row1 col6\" >0.236789</td>\n",
1378
+ " <td id=\"T_0de30_row1_col7\" class=\"data row1 col7\" >0.671979</td>\n",
1379
+ " </tr>\n",
1380
+ " <tr>\n",
1381
+ " <th id=\"T_0de30_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
1382
+ " <td id=\"T_0de30_row2_col0\" class=\"data row2 col0\" >NB</td>\n",
1383
+ " <td id=\"T_0de30_row2_col1\" class=\"data row2 col1\" >0.699608</td>\n",
1384
+ " <td id=\"T_0de30_row2_col2\" class=\"data row2 col2\" >0.381260</td>\n",
1385
+ " <td id=\"T_0de30_row2_col3\" class=\"data row2 col3\" >0.345703</td>\n",
1386
+ " <td id=\"T_0de30_row2_col4\" class=\"data row2 col4\" >0.424970</td>\n",
1387
+ " <td id=\"T_0de30_row2_col5\" class=\"data row2 col5\" >0.634031</td>\n",
1388
+ " <td id=\"T_0de30_row2_col6\" class=\"data row2 col6\" >0.268468</td>\n",
1389
+ " <td id=\"T_0de30_row2_col7\" class=\"data row2 col7\" >1.659029</td>\n",
1390
+ " </tr>\n",
1391
+ " <tr>\n",
1392
+ " <th id=\"T_0de30_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
1393
+ " <td id=\"T_0de30_row3_col0\" class=\"data row3 col0\" >RF</td>\n",
1394
+ " <td id=\"T_0de30_row3_col1\" class=\"data row3 col1\" >0.733595</td>\n",
1395
+ " <td id=\"T_0de30_row3_col2\" class=\"data row3 col2\" >0.242379</td>\n",
1396
+ " <td id=\"T_0de30_row3_col3\" class=\"data row3 col3\" >0.318359</td>\n",
1397
+ " <td id=\"T_0de30_row3_col4\" class=\"data row3 col4\" >0.195678</td>\n",
1398
+ " <td id=\"T_0de30_row3_col5\" class=\"data row3 col5\" >0.601436</td>\n",
1399
+ " <td id=\"T_0de30_row3_col6\" class=\"data row3 col6\" >0.188506</td>\n",
1400
+ " <td id=\"T_0de30_row3_col7\" class=\"data row3 col7\" >0.743245</td>\n",
1401
+ " </tr>\n",
1402
+ " <tr>\n",
1403
+ " <th id=\"T_0de30_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
1404
+ " <td id=\"T_0de30_row4_col0\" class=\"data row4 col0\" >KNN</td>\n",
1405
+ " <td id=\"T_0de30_row4_col1\" class=\"data row4 col1\" >0.754379</td>\n",
1406
+ " <td id=\"T_0de30_row4_col2\" class=\"data row4 col2\" >0.214136</td>\n",
1407
+ " <td id=\"T_0de30_row4_col3\" class=\"data row4 col3\" >0.353103</td>\n",
1408
+ " <td id=\"T_0de30_row4_col4\" class=\"data row4 col4\" >0.153661</td>\n",
1409
+ " <td id=\"T_0de30_row4_col5\" class=\"data row4 col5\" >0.582555</td>\n",
1410
+ " <td id=\"T_0de30_row4_col6\" class=\"data row4 col6\" >0.189077</td>\n",
1411
+ " <td id=\"T_0de30_row4_col7\" class=\"data row4 col7\" >2.428060</td>\n",
1412
+ " </tr>\n",
1413
+ " <tr>\n",
1414
+ " <th id=\"T_0de30_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
1415
+ " <td id=\"T_0de30_row5_col0\" class=\"data row5 col0\" >GBC</td>\n",
1416
+ " <td id=\"T_0de30_row5_col1\" class=\"data row5 col1\" >0.781830</td>\n",
1417
+ " <td id=\"T_0de30_row5_col2\" class=\"data row5 col2\" >0.085479</td>\n",
1418
+ " <td id=\"T_0de30_row5_col3\" class=\"data row5 col3\" >0.490566</td>\n",
1419
+ " <td id=\"T_0de30_row5_col4\" class=\"data row5 col4\" >0.046819</td>\n",
1420
+ " <td id=\"T_0de30_row5_col5\" class=\"data row5 col5\" >0.669660</td>\n",
1421
+ " <td id=\"T_0de30_row5_col6\" class=\"data row5 col6\" >0.159691</td>\n",
1422
+ " <td id=\"T_0de30_row5_col7\" class=\"data row5 col7\" >0.493648</td>\n",
1423
+ " </tr>\n",
1424
+ " </tbody>\n",
1425
+ "</table>\n"
1426
+ ],
1427
+ "text/plain": [
1428
+ "<pandas.io.formats.style.Styler at 0x21e044360e0>"
1429
+ ]
1430
+ },
1431
+ "metadata": {},
1432
+ "output_type": "display_data"
1433
+ }
1434
+ ],
1435
+ "source": [
1436
+ "df_mdl.sort_values(\n",
1437
+ " \"test_f1\",\n",
1438
+ " ascending=False,\n",
1439
+ " inplace=True,\n",
1440
+ " ignore_index=True)\n",
1441
+ "\n",
1442
+ "display(df_mdl.style.set_caption(\"Test set validation scores\"))"
1443
+ ]
1444
+ },
1445
+ {
1446
+ "cell_type": "markdown",
1447
+ "metadata": {},
1448
+ "source": [
1449
+ "Any of models present good results! We will try to fit a composite model with the 3 better."
1450
+ ]
1451
+ },
1452
+ {
1453
+ "cell_type": "code",
1454
+ "execution_count": 21,
1455
+ "metadata": {},
1456
+ "outputs": [
1457
+ {
1458
+ "data": {
1459
+ "text/html": [
1460
+ "<style type=\"text/css\">\n",
1461
+ "</style>\n",
1462
+ "<table id=\"T_ca02f\">\n",
1463
+ " <caption>Test set validation scores for Composite Model</caption>\n",
1464
+ " <thead>\n",
1465
+ " <tr>\n",
1466
+ " <th class=\"blank level0\" >&nbsp;</th>\n",
1467
+ " <th id=\"T_ca02f_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
1468
+ " <th id=\"T_ca02f_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
1469
+ " <th id=\"T_ca02f_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
1470
+ " <th id=\"T_ca02f_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
1471
+ " <th id=\"T_ca02f_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
1472
+ " <th id=\"T_ca02f_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
1473
+ " </tr>\n",
1474
+ " </thead>\n",
1475
+ " <tbody>\n",
1476
+ " <tr>\n",
1477
+ " <th id=\"T_ca02f_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
1478
+ " <td id=\"T_ca02f_row0_col0\" class=\"data row0 col0\" >0.622165</td>\n",
1479
+ " <td id=\"T_ca02f_row0_col1\" class=\"data row0 col1\" >0.732543</td>\n",
1480
+ " <td id=\"T_ca02f_row0_col2\" class=\"data row0 col2\" >0.591849</td>\n",
1481
+ " <td id=\"T_ca02f_row0_col3\" class=\"data row0 col3\" >0.699149</td>\n",
1482
+ " <td id=\"T_ca02f_row0_col4\" class=\"data row0 col4\" >0.617794</td>\n",
1483
+ " <td id=\"T_ca02f_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
1484
+ " </tr>\n",
1485
+ " <tr>\n",
1486
+ " <th id=\"T_ca02f_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
1487
+ " <td id=\"T_ca02f_row1_col0\" class=\"data row1 col0\" >0.023785</td>\n",
1488
+ " <td id=\"T_ca02f_row1_col1\" class=\"data row1 col1\" >0.027777</td>\n",
1489
+ " <td id=\"T_ca02f_row1_col2\" class=\"data row1 col2\" >0.030991</td>\n",
1490
+ " <td id=\"T_ca02f_row1_col3\" class=\"data row1 col3\" >0.027807</td>\n",
1491
+ " <td id=\"T_ca02f_row1_col4\" class=\"data row1 col4\" >0.023930</td>\n",
1492
+ " <td id=\"T_ca02f_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
1493
+ " </tr>\n",
1494
+ " <tr>\n",
1495
+ " <th id=\"T_ca02f_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
1496
+ " <td id=\"T_ca02f_row2_col0\" class=\"data row2 col0\" >0.714846</td>\n",
1497
+ " <td id=\"T_ca02f_row2_col1\" class=\"data row2 col1\" >0.701120</td>\n",
1498
+ " <td id=\"T_ca02f_row2_col2\" class=\"data row2 col2\" >0.695713</td>\n",
1499
+ " <td id=\"T_ca02f_row2_col3\" class=\"data row2 col3\" >0.693752</td>\n",
1500
+ " <td id=\"T_ca02f_row2_col4\" class=\"data row2 col4\" >0.677221</td>\n",
1501
+ " <td id=\"T_ca02f_row2_col5\" class=\"data row2 col5\" >0.699346</td>\n",
1502
+ " </tr>\n",
1503
+ " <tr>\n",
1504
+ " <th id=\"T_ca02f_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
1505
+ " <td id=\"T_ca02f_row3_col0\" class=\"data row3 col0\" >0.412240</td>\n",
1506
+ " <td id=\"T_ca02f_row3_col1\" class=\"data row3 col1\" >0.404243</td>\n",
1507
+ " <td id=\"T_ca02f_row3_col2\" class=\"data row3 col2\" >0.389201</td>\n",
1508
+ " <td id=\"T_ca02f_row3_col3\" class=\"data row3 col3\" >0.385610</td>\n",
1509
+ " <td id=\"T_ca02f_row3_col4\" class=\"data row3 col4\" >0.369803</td>\n",
1510
+ " <td id=\"T_ca02f_row3_col5\" class=\"data row3 col5\" >0.389597</td>\n",
1511
+ " </tr>\n",
1512
+ " <tr>\n",
1513
+ " <th id=\"T_ca02f_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
1514
+ " <td id=\"T_ca02f_row4_col0\" class=\"data row4 col0\" >0.374214</td>\n",
1515
+ " <td id=\"T_ca02f_row4_col1\" class=\"data row4 col1\" >0.357354</td>\n",
1516
+ " <td id=\"T_ca02f_row4_col2\" class=\"data row4 col2\" >0.345654</td>\n",
1517
+ " <td id=\"T_ca02f_row4_col3\" class=\"data row4 col3\" >0.342315</td>\n",
1518
+ " <td id=\"T_ca02f_row4_col4\" class=\"data row4 col4\" >0.321905</td>\n",
1519
+ " <td id=\"T_ca02f_row4_col5\" class=\"data row4 col5\" >0.349191</td>\n",
1520
+ " </tr>\n",
1521
+ " <tr>\n",
1522
+ " <th id=\"T_ca02f_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
1523
+ " <td id=\"T_ca02f_row5_col0\" class=\"data row5 col0\" >0.458869</td>\n",
1524
+ " <td id=\"T_ca02f_row5_col1\" class=\"data row5 col1\" >0.465296</td>\n",
1525
+ " <td id=\"T_ca02f_row5_col2\" class=\"data row5 col2\" >0.445302</td>\n",
1526
+ " <td id=\"T_ca02f_row5_col3\" class=\"data row5 col3\" >0.441441</td>\n",
1527
+ " <td id=\"T_ca02f_row5_col4\" class=\"data row5 col4\" >0.434447</td>\n",
1528
+ " <td id=\"T_ca02f_row5_col5\" class=\"data row5 col5\" >0.440576</td>\n",
1529
+ " </tr>\n",
1530
+ " <tr>\n",
1531
+ " <th id=\"T_ca02f_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
1532
+ " <td id=\"T_ca02f_row6_col0\" class=\"data row6 col0\" >0.679117</td>\n",
1533
+ " <td id=\"T_ca02f_row6_col1\" class=\"data row6 col1\" >0.662225</td>\n",
1534
+ " <td id=\"T_ca02f_row6_col2\" class=\"data row6 col2\" >0.651689</td>\n",
1535
+ " <td id=\"T_ca02f_row6_col3\" class=\"data row6 col3\" >0.658722</td>\n",
1536
+ " <td id=\"T_ca02f_row6_col4\" class=\"data row6 col4\" >0.640881</td>\n",
1537
+ " <td id=\"T_ca02f_row6_col5\" class=\"data row6 col5\" >0.658847</td>\n",
1538
+ " </tr>\n",
1539
+ " <tr>\n",
1540
+ " <th id=\"T_ca02f_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
1541
+ " <td id=\"T_ca02f_row7_col0\" class=\"data row7 col0\" >-0.199904</td>\n",
1542
+ " <td id=\"T_ca02f_row7_col1\" class=\"data row7 col1\" >-0.208466</td>\n",
1543
+ " <td id=\"T_ca02f_row7_col2\" class=\"data row7 col2\" >-0.211428</td>\n",
1544
+ " <td id=\"T_ca02f_row7_col3\" class=\"data row7 col3\" >-0.209929</td>\n",
1545
+ " <td id=\"T_ca02f_row7_col4\" class=\"data row7 col4\" >-0.218236</td>\n",
1546
+ " <td id=\"T_ca02f_row7_col5\" class=\"data row7 col5\" >-0.208991</td>\n",
1547
+ " </tr>\n",
1548
+ " <tr>\n",
1549
+ " <th id=\"T_ca02f_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
1550
+ " <td id=\"T_ca02f_row8_col0\" class=\"data row8 col0\" >-0.590700</td>\n",
1551
+ " <td id=\"T_ca02f_row8_col1\" class=\"data row8 col1\" >-0.611156</td>\n",
1552
+ " <td id=\"T_ca02f_row8_col2\" class=\"data row8 col2\" >-0.616876</td>\n",
1553
+ " <td id=\"T_ca02f_row8_col3\" class=\"data row8 col3\" >-0.613285</td>\n",
1554
+ " <td id=\"T_ca02f_row8_col4\" class=\"data row8 col4\" >-0.633314</td>\n",
1555
+ " <td id=\"T_ca02f_row8_col5\" class=\"data row8 col5\" >-0.611563</td>\n",
1556
+ " </tr>\n",
1557
+ " </tbody>\n",
1558
+ "</table>\n"
1559
+ ],
1560
+ "text/plain": [
1561
+ "<pandas.io.formats.style.Styler at 0x21e04333100>"
1562
+ ]
1563
+ },
1564
+ "metadata": {},
1565
+ "output_type": "display_data"
1566
+ }
1567
+ ],
1568
+ "source": [
1569
+ "# Selecting the models\n",
1570
+ "cls_name = [\"LR\", \"NB\", \"XGB\"]\n",
1571
+ "cls_list = [cls_lr, cls_nb, cls_xgb]\n",
1572
+ "\n",
1573
+ "# Training the voting classifier\n",
1574
+ "cls_vot = VotingClassifier([*zip(cls_name, cls_list)], voting=\"soft\")\n",
1575
+ "cls_vot.fit(X_train, y_train)\n",
1576
+ "\n",
1577
+ "# Using cross-validation to evaluate the model fitted\n",
1578
+ "cls_cross = cross_validate(\n",
1579
+ " estimator=cls_vot,\n",
1580
+ " X=X_train,\n",
1581
+ " y=y_train,\n",
1582
+ " cv=5,\n",
1583
+ " scoring=scores)\n",
1584
+ "\n",
1585
+ "df_vot = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
1586
+ "\n",
1587
+ "# Calculating score to test set\n",
1588
+ "accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls_vot)\n",
1589
+ "\n",
1590
+ "# Filling a dataframe to better presentation\n",
1591
+ "df_vot.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
1592
+ "df_vot.at[\"test_f1\", \"TestSet\"] = f1\n",
1593
+ "df_vot.at[\"test_recall\", \"TestSet\"] = recall\n",
1594
+ "df_vot.at[\"test_precision\", \"TestSet\"] = precision\n",
1595
+ "df_vot.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
1596
+ "df_vot.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
1597
+ "df_vot.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
1598
+ "\n",
1599
+ "display(df_vot.style.set_caption(\"Test set validation scores for Composite Model\"))"
1600
+ ]
1601
+ },
1602
+ {
1603
+ "cell_type": "markdown",
1604
+ "metadata": {},
1605
+ "source": [
1606
+ "The composite model is not better than neat models. Well, maybe some tuning could handle this. But this will be done in future work."
1607
+ ]
1608
+ },
1609
+ {
1610
+ "cell_type": "code",
1611
+ "execution_count": 11,
1612
+ "metadata": {},
1613
+ "outputs": [
1614
+ {
1615
+ "data": {
1616
+ "text/plain": [
1617
+ "['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\model_feridos.pkl']"
1618
+ ]
1619
+ },
1620
+ "execution_count": 11,
1621
+ "metadata": {},
1622
+ "output_type": "execute_result"
1623
+ }
1624
+ ],
1625
+ "source": [
1626
+ "# Saving\n",
1627
+ "# file_name = \"model_\" + output + '.pkl'\n",
1628
+ "# jb.dump(cls_vot, path.join(path.abspath(\"./\"), file_name))"
1629
+ ]
1630
+ }
1631
+ ],
1632
+ "metadata": {
1633
+ "kernelspec": {
1634
+ "display_name": "Python 3.10.6 64-bit",
1635
+ "language": "python",
1636
+ "name": "python3"
1637
+ },
1638
+ "language_info": {
1639
+ "codemirror_mode": {
1640
+ "name": "ipython",
1641
+ "version": 3
1642
+ },
1643
+ "file_extension": ".py",
1644
+ "mimetype": "text/x-python",
1645
+ "name": "python",
1646
+ "nbconvert_exporter": "python",
1647
+ "pygments_lexer": "ipython3",
1648
+ "version": "3.10.6"
1649
+ },
1650
+ "orig_nbformat": 4,
1651
+ "vscode": {
1652
+ "interpreter": {
1653
+ "hash": "1372d04dbd71fdc5436c5d6e671c1b9287e750e86143c81b5a7ba0acaf653c5e"
1654
+ }
1655
+ }
1656
+ },
1657
+ "nbformat": 4,
1658
+ "nbformat_minor": 2
1659
+ }
model/scaler_feridos.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf819a52c70bb4058de5851aedbd92da179246c132e87593420f87850f7ae1b0
3
+ size 2387
model/scaler_feridos_gr.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:063891bf93f83929da6294c645de8d503fcec2c756cde601933800c7d3dc6598
3
+ size 2387
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ joblib
2
+ numpy
3
+ pandas
4
+ gradio
5
+ scikit-learn
6
+ xgboost=1.6.2