Spaces:
Runtime error
Runtime error
GregOliveira
commited on
Commit
•
9605944
1
Parent(s):
cedfa43
app v0
Browse files- app.py +149 -0
- model/model_features.csv +1 -0
- model/model_feridos.ipynb +1659 -0
- model/model_feridos.pkl +3 -0
- model/model_feridos_gr.ipynb +1659 -0
- model/scaler_feridos.pkl +3 -0
- model/scaler_feridos_gr.pkl +3 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import gradio as gr
|
5 |
+
import joblib as jb
|
6 |
+
import os.path as path
|
7 |
+
|
8 |
+
import warnings
|
9 |
+
warnings.filterwarnings("ignore")

# Loading features names
# The first CSV row lists the model's input columns; fit_inputs_injured
# builds its input vector from these names, in this order.
file_csv = path.join("model", "model_features.csv")
try:
    # Read as UTF-8 explicitly instead of relying on the platform default,
    # because one feature name contains a non-ASCII character.
    with open(file_csv, encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        data = list(reader)
except UnicodeDecodeError:
    # NOTE(review): the feature file appears to have been written with a
    # Windows default code page (the training notebook opens it without an
    # explicit encoding) -- confirm, or re-export the file as UTF-8.
    with open(file_csv, encoding="cp1252", newline="") as f:
        reader = csv.reader(f)
        data = list(reader)

features = data[0]

# Creating a list with accident types
# Indexed by the position of the choice selected in the "Accident type"
# dropdown; index 0 has no dedicated column (all type_* flags stay False).
accident_type_list = [None,
                      "type_ATROPELAMENTO",
                      "type_CHOQUE",
                      "type_COLISÃO",
                      "type_OUTROS"]

# Loading the scaler
# NOTE: joblib.load unpickles the file; only load artifacts shipped with
# this repository, never files supplied by users.
file_scaler_feridos = path.join("model", "scaler_feridos.pkl")
scaler_feridos = jb.load(file_scaler_feridos)

# Loading the model
file_model_feridos = path.join("model", "model_feridos.pkl")
model_feridos = jb.load(file_model_feridos)
|
33 |
+
|
34 |
+
def fit_inputs_injured(latitude,
                       longitude,
                       caminhao,
                       moto,
                       cars,
                       transport,
                       others,
                       holiday,
                       week_day,
                       hour_day,
                       accident_type) -> np.ndarray:
    """Turn the raw form values into the scaled one-row input of the model.

    Every column named in ``features`` starts as ``False``; the direct
    inputs are copied in and at most one ``day_*``, one ``hour_*`` and one
    ``type_*`` flag is switched on.

    Args:
        latitude, longitude: Coordinates of the accident.
        caminhao, moto, cars, transport, others: Vehicle-involvement flags.
        holiday: Whether the day is a holiday.
        week_day: Index of the selected week day. ``0`` or ``None`` (nothing
            selected) leaves every ``day_*`` flag off.
        hour_day: Hour of the day. ``0`` or ``None`` leaves every ``hour_*``
            flag off.
        accident_type: Index into ``accident_type_list``. ``0`` or ``None``
            leaves every ``type_*`` flag off.

    Returns:
        The array produced by ``scaler_feridos.transform`` for the single
        input row.

    Raises:
        ValueError: If a selected category maps to a column that is not in
            ``features``.
        IndexError: If ``accident_type`` is outside ``accident_type_list``.
    """
    input_dict = {col: False for col in features}

    input_dict["latitude"] = latitude
    input_dict["longitude"] = longitude
    input_dict["caminhao"] = caminhao
    input_dict["moto"] = moto
    input_dict["cars"] = cars
    input_dict["transport"] = transport
    input_dict["others"] = others
    input_dict["holiday"] = holiday

    def _activate(column):
        # An unknown name would silently append an extra column and make the
        # scaler fail on the feature count; fail early with a clear message.
        if column not in input_dict:
            raise ValueError(f"Unknown model feature: {column!r}")
        input_dict[column] = True

    # A falsy value (0 or None) means "baseline category": no flag is set.
    # int() guards against the widgets delivering floats such as 5.0,
    # which would otherwise build a name like "hour_5.0".
    if week_day:
        _activate("day_" + str(int(week_day)))

    if hour_day:
        _activate("hour_" + str(int(hour_day)))

    if accident_type:
        _activate(accident_type_list[int(accident_type)])

    # Dict order is the order of ``features``; booleans become 0.0 / 1.0.
    input_array = np.array(list(input_dict.values()),
                           dtype=float).reshape(1, -1)

    input_scaled = scaler_feridos.transform(input_array)

    return input_scaled
|
74 |
+
|
75 |
+
def predict(
        latitude,
        longitude,
        caminhao,
        moto,
        cars,
        transport,
        others,
        holiday,
        week_day,
        hour_day,
        accident_type) -> dict:
    """Callback run by gradio when the form is submitted.

    Prepares the form values with ``fit_inputs_injured``, asks
    ``model_feridos`` for class probabilities and returns them keyed by
    the labels "No" (first probability) and "Yes" (second probability).
    """
    model_input = fit_inputs_injured(
        latitude=latitude,
        longitude=longitude,
        caminhao=caminhao,
        moto=moto,
        cars=cars,
        transport=transport,
        others=others,
        holiday=holiday,
        week_day=week_day,
        hour_day=hour_day,
        accident_type=accident_type,
    )

    # Only one row is submitted, so only the first row of the result is read.
    row_probabilities = model_feridos.predict_proba(model_input)[0]

    return {"No": row_probabilities[0], "Yes": row_probabilities[1]}
|
105 |
+
|
106 |
+
# Web form for the injuries model. The inputs are listed in the same order
# as the positional parameters of ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Slider(
            minimum=-31.054,
            maximum=-29.054,
            step=0.001,
            value=-30.054,
            label="Latitude"),
        gr.Slider(
            minimum=-52.196,
            maximum=-50.196,
            step=0.001,
            value=-51.196,
            label="Longitude"),
        gr.Checkbox(label="Trucks involved?"),
        gr.Checkbox(label="Motorcycle involved?"),
        gr.Checkbox(label="Cars involved?"),
        gr.Checkbox(label="Bus involved?"),
        gr.Checkbox(label="Other vehicle (i.e. scooter) involved?"),
        gr.Checkbox(label="Is holiday?"),
        gr.Radio(
            choices=["Sun", "Mon",
                     "Tue", "Wed",
                     "Thu", "Fri",
                     "Sat"],
            # Preselect the first choice: with type="index" an untouched
            # widget would otherwise hand None to ``predict``.
            value="Sun",
            type="index",
            label="Day of Week"),
        gr.Slider(
            minimum=0,
            maximum=23,
            step=1,
            label="Hour"),
        gr.Dropdown(
            choices=["Violent Collision",
                     "Running over",
                     "Shock",
                     "Collision",
                     "Other"],
            # Same reason as above: avoid a None index on submit.
            value="Violent Collision",
            type="index",
            label="Accident type"),
    ],
    outputs=gr.Label(
        label="Are there people injured?"))

demo.launch()
|
model/model_features.csv
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
latitude,longitude,caminhao,moto,cars,transport,others,holiday,day_1,day_2,day_3,day_4,day_5,day_6,hour_1,hour_2,hour_3,hour_4,hour_5,hour_6,hour_7,hour_8,hour_9,hour_10,hour_11,hour_12,hour_13,hour_14,hour_15,hour_16,hour_17,hour_18,hour_19,hour_20,hour_21,hour_22,hour_23,type_ATROPELAMENTO,type_CHOQUE,type_COLISÃO,type_OUTROS
|
model/model_feridos.ipynb
ADDED
@@ -0,0 +1,1659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# 1. Introduction\n",
|
8 |
+
"\n",
|
9 |
+
"This notebook was written to train Porto Alegre Traffic Accidents Data after the first cleaning, processing, and transforming step. This was made in a notebook in the `data` folder. In truth, we will have 3 models.\n",
|
10 |
+
"\n",
|
11 |
+
"1. Predict the probability of injured people.\n",
|
12 |
+
"\n",
|
13 |
+
"2. Predict the probability of seriously injured people.\n",
|
14 |
+
"\n",
|
15 |
+
"3. Predict the probability of dead people in the event or after it.\n",
|
16 |
+
"\n",
|
17 |
+
"The path to training the models will be the same, just make some filtering on data and analyze the results properly."
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {},
|
23 |
+
"source": [
|
24 |
+
"# 2. Data Loading"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 1,
|
30 |
+
"metadata": {},
|
31 |
+
"outputs": [
|
32 |
+
{
|
33 |
+
"data": {
|
34 |
+
"text/html": [
|
35 |
+
"<div>\n",
|
36 |
+
"<style scoped>\n",
|
37 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
38 |
+
" vertical-align: middle;\n",
|
39 |
+
" }\n",
|
40 |
+
"\n",
|
41 |
+
" .dataframe tbody tr th {\n",
|
42 |
+
" vertical-align: top;\n",
|
43 |
+
" }\n",
|
44 |
+
"\n",
|
45 |
+
" .dataframe thead th {\n",
|
46 |
+
" text-align: right;\n",
|
47 |
+
" }\n",
|
48 |
+
"</style>\n",
|
49 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
50 |
+
" <thead>\n",
|
51 |
+
" <tr style=\"text-align: right;\">\n",
|
52 |
+
" <th></th>\n",
|
53 |
+
" <th>0</th>\n",
|
54 |
+
" <th>1</th>\n",
|
55 |
+
" <th>2</th>\n",
|
56 |
+
" </tr>\n",
|
57 |
+
" </thead>\n",
|
58 |
+
" <tbody>\n",
|
59 |
+
" <tr>\n",
|
60 |
+
" <th>latitude</th>\n",
|
61 |
+
" <td>-30.009614</td>\n",
|
62 |
+
" <td>-30.0403</td>\n",
|
63 |
+
" <td>-30.069</td>\n",
|
64 |
+
" </tr>\n",
|
65 |
+
" <tr>\n",
|
66 |
+
" <th>longitude</th>\n",
|
67 |
+
" <td>-51.185581</td>\n",
|
68 |
+
" <td>-51.1958</td>\n",
|
69 |
+
" <td>-51.1437</td>\n",
|
70 |
+
" </tr>\n",
|
71 |
+
" <tr>\n",
|
72 |
+
" <th>feridos</th>\n",
|
73 |
+
" <td>True</td>\n",
|
74 |
+
" <td>True</td>\n",
|
75 |
+
" <td>True</td>\n",
|
76 |
+
" </tr>\n",
|
77 |
+
" <tr>\n",
|
78 |
+
" <th>feridos_gr</th>\n",
|
79 |
+
" <td>False</td>\n",
|
80 |
+
" <td>False</td>\n",
|
81 |
+
" <td>False</td>\n",
|
82 |
+
" </tr>\n",
|
83 |
+
" <tr>\n",
|
84 |
+
" <th>fatais</th>\n",
|
85 |
+
" <td>False</td>\n",
|
86 |
+
" <td>False</td>\n",
|
87 |
+
" <td>False</td>\n",
|
88 |
+
" </tr>\n",
|
89 |
+
" <tr>\n",
|
90 |
+
" <th>caminhao</th>\n",
|
91 |
+
" <td>False</td>\n",
|
92 |
+
" <td>False</td>\n",
|
93 |
+
" <td>False</td>\n",
|
94 |
+
" </tr>\n",
|
95 |
+
" <tr>\n",
|
96 |
+
" <th>moto</th>\n",
|
97 |
+
" <td>True</td>\n",
|
98 |
+
" <td>True</td>\n",
|
99 |
+
" <td>False</td>\n",
|
100 |
+
" </tr>\n",
|
101 |
+
" <tr>\n",
|
102 |
+
" <th>cars</th>\n",
|
103 |
+
" <td>True</td>\n",
|
104 |
+
" <td>True</td>\n",
|
105 |
+
" <td>True</td>\n",
|
106 |
+
" </tr>\n",
|
107 |
+
" <tr>\n",
|
108 |
+
" <th>transport</th>\n",
|
109 |
+
" <td>False</td>\n",
|
110 |
+
" <td>False</td>\n",
|
111 |
+
" <td>False</td>\n",
|
112 |
+
" </tr>\n",
|
113 |
+
" <tr>\n",
|
114 |
+
" <th>others</th>\n",
|
115 |
+
" <td>False</td>\n",
|
116 |
+
" <td>False</td>\n",
|
117 |
+
" <td>False</td>\n",
|
118 |
+
" </tr>\n",
|
119 |
+
" <tr>\n",
|
120 |
+
" <th>holiday</th>\n",
|
121 |
+
" <td>False</td>\n",
|
122 |
+
" <td>True</td>\n",
|
123 |
+
" <td>True</td>\n",
|
124 |
+
" </tr>\n",
|
125 |
+
" <tr>\n",
|
126 |
+
" <th>day_1</th>\n",
|
127 |
+
" <td>0</td>\n",
|
128 |
+
" <td>0</td>\n",
|
129 |
+
" <td>0</td>\n",
|
130 |
+
" </tr>\n",
|
131 |
+
" <tr>\n",
|
132 |
+
" <th>day_2</th>\n",
|
133 |
+
" <td>0</td>\n",
|
134 |
+
" <td>0</td>\n",
|
135 |
+
" <td>0</td>\n",
|
136 |
+
" </tr>\n",
|
137 |
+
" <tr>\n",
|
138 |
+
" <th>day_3</th>\n",
|
139 |
+
" <td>0</td>\n",
|
140 |
+
" <td>0</td>\n",
|
141 |
+
" <td>0</td>\n",
|
142 |
+
" </tr>\n",
|
143 |
+
" <tr>\n",
|
144 |
+
" <th>day_4</th>\n",
|
145 |
+
" <td>0</td>\n",
|
146 |
+
" <td>0</td>\n",
|
147 |
+
" <td>0</td>\n",
|
148 |
+
" </tr>\n",
|
149 |
+
" <tr>\n",
|
150 |
+
" <th>day_5</th>\n",
|
151 |
+
" <td>1</td>\n",
|
152 |
+
" <td>0</td>\n",
|
153 |
+
" <td>0</td>\n",
|
154 |
+
" </tr>\n",
|
155 |
+
" <tr>\n",
|
156 |
+
" <th>day_6</th>\n",
|
157 |
+
" <td>0</td>\n",
|
158 |
+
" <td>1</td>\n",
|
159 |
+
" <td>1</td>\n",
|
160 |
+
" </tr>\n",
|
161 |
+
" <tr>\n",
|
162 |
+
" <th>hour_1</th>\n",
|
163 |
+
" <td>0</td>\n",
|
164 |
+
" <td>0</td>\n",
|
165 |
+
" <td>0</td>\n",
|
166 |
+
" </tr>\n",
|
167 |
+
" <tr>\n",
|
168 |
+
" <th>hour_2</th>\n",
|
169 |
+
" <td>0</td>\n",
|
170 |
+
" <td>0</td>\n",
|
171 |
+
" <td>0</td>\n",
|
172 |
+
" </tr>\n",
|
173 |
+
" <tr>\n",
|
174 |
+
" <th>hour_3</th>\n",
|
175 |
+
" <td>0</td>\n",
|
176 |
+
" <td>0</td>\n",
|
177 |
+
" <td>0</td>\n",
|
178 |
+
" </tr>\n",
|
179 |
+
" <tr>\n",
|
180 |
+
" <th>hour_4</th>\n",
|
181 |
+
" <td>0</td>\n",
|
182 |
+
" <td>0</td>\n",
|
183 |
+
" <td>0</td>\n",
|
184 |
+
" </tr>\n",
|
185 |
+
" <tr>\n",
|
186 |
+
" <th>hour_5</th>\n",
|
187 |
+
" <td>0</td>\n",
|
188 |
+
" <td>0</td>\n",
|
189 |
+
" <td>0</td>\n",
|
190 |
+
" </tr>\n",
|
191 |
+
" <tr>\n",
|
192 |
+
" <th>hour_6</th>\n",
|
193 |
+
" <td>0</td>\n",
|
194 |
+
" <td>0</td>\n",
|
195 |
+
" <td>0</td>\n",
|
196 |
+
" </tr>\n",
|
197 |
+
" <tr>\n",
|
198 |
+
" <th>hour_7</th>\n",
|
199 |
+
" <td>0</td>\n",
|
200 |
+
" <td>0</td>\n",
|
201 |
+
" <td>0</td>\n",
|
202 |
+
" </tr>\n",
|
203 |
+
" <tr>\n",
|
204 |
+
" <th>hour_8</th>\n",
|
205 |
+
" <td>0</td>\n",
|
206 |
+
" <td>0</td>\n",
|
207 |
+
" <td>0</td>\n",
|
208 |
+
" </tr>\n",
|
209 |
+
" <tr>\n",
|
210 |
+
" <th>hour_9</th>\n",
|
211 |
+
" <td>0</td>\n",
|
212 |
+
" <td>0</td>\n",
|
213 |
+
" <td>0</td>\n",
|
214 |
+
" </tr>\n",
|
215 |
+
" <tr>\n",
|
216 |
+
" <th>hour_10</th>\n",
|
217 |
+
" <td>0</td>\n",
|
218 |
+
" <td>1</td>\n",
|
219 |
+
" <td>0</td>\n",
|
220 |
+
" </tr>\n",
|
221 |
+
" <tr>\n",
|
222 |
+
" <th>hour_11</th>\n",
|
223 |
+
" <td>0</td>\n",
|
224 |
+
" <td>0</td>\n",
|
225 |
+
" <td>0</td>\n",
|
226 |
+
" </tr>\n",
|
227 |
+
" <tr>\n",
|
228 |
+
" <th>hour_12</th>\n",
|
229 |
+
" <td>0</td>\n",
|
230 |
+
" <td>0</td>\n",
|
231 |
+
" <td>0</td>\n",
|
232 |
+
" </tr>\n",
|
233 |
+
" <tr>\n",
|
234 |
+
" <th>hour_13</th>\n",
|
235 |
+
" <td>0</td>\n",
|
236 |
+
" <td>0</td>\n",
|
237 |
+
" <td>0</td>\n",
|
238 |
+
" </tr>\n",
|
239 |
+
" <tr>\n",
|
240 |
+
" <th>hour_14</th>\n",
|
241 |
+
" <td>0</td>\n",
|
242 |
+
" <td>0</td>\n",
|
243 |
+
" <td>0</td>\n",
|
244 |
+
" </tr>\n",
|
245 |
+
" <tr>\n",
|
246 |
+
" <th>hour_15</th>\n",
|
247 |
+
" <td>0</td>\n",
|
248 |
+
" <td>0</td>\n",
|
249 |
+
" <td>0</td>\n",
|
250 |
+
" </tr>\n",
|
251 |
+
" <tr>\n",
|
252 |
+
" <th>hour_16</th>\n",
|
253 |
+
" <td>0</td>\n",
|
254 |
+
" <td>0</td>\n",
|
255 |
+
" <td>0</td>\n",
|
256 |
+
" </tr>\n",
|
257 |
+
" <tr>\n",
|
258 |
+
" <th>hour_17</th>\n",
|
259 |
+
" <td>0</td>\n",
|
260 |
+
" <td>0</td>\n",
|
261 |
+
" <td>0</td>\n",
|
262 |
+
" </tr>\n",
|
263 |
+
" <tr>\n",
|
264 |
+
" <th>hour_18</th>\n",
|
265 |
+
" <td>0</td>\n",
|
266 |
+
" <td>0</td>\n",
|
267 |
+
" <td>0</td>\n",
|
268 |
+
" </tr>\n",
|
269 |
+
" <tr>\n",
|
270 |
+
" <th>hour_19</th>\n",
|
271 |
+
" <td>1</td>\n",
|
272 |
+
" <td>0</td>\n",
|
273 |
+
" <td>1</td>\n",
|
274 |
+
" </tr>\n",
|
275 |
+
" <tr>\n",
|
276 |
+
" <th>hour_20</th>\n",
|
277 |
+
" <td>0</td>\n",
|
278 |
+
" <td>0</td>\n",
|
279 |
+
" <td>0</td>\n",
|
280 |
+
" </tr>\n",
|
281 |
+
" <tr>\n",
|
282 |
+
" <th>hour_21</th>\n",
|
283 |
+
" <td>0</td>\n",
|
284 |
+
" <td>0</td>\n",
|
285 |
+
" <td>0</td>\n",
|
286 |
+
" </tr>\n",
|
287 |
+
" <tr>\n",
|
288 |
+
" <th>hour_22</th>\n",
|
289 |
+
" <td>0</td>\n",
|
290 |
+
" <td>0</td>\n",
|
291 |
+
" <td>0</td>\n",
|
292 |
+
" </tr>\n",
|
293 |
+
" <tr>\n",
|
294 |
+
" <th>hour_23</th>\n",
|
295 |
+
" <td>0</td>\n",
|
296 |
+
" <td>0</td>\n",
|
297 |
+
" <td>0</td>\n",
|
298 |
+
" </tr>\n",
|
299 |
+
" <tr>\n",
|
300 |
+
" <th>type_ATROPELAMENTO</th>\n",
|
301 |
+
" <td>0</td>\n",
|
302 |
+
" <td>0</td>\n",
|
303 |
+
" <td>1</td>\n",
|
304 |
+
" </tr>\n",
|
305 |
+
" <tr>\n",
|
306 |
+
" <th>type_CHOQUE</th>\n",
|
307 |
+
" <td>0</td>\n",
|
308 |
+
" <td>0</td>\n",
|
309 |
+
" <td>0</td>\n",
|
310 |
+
" </tr>\n",
|
311 |
+
" <tr>\n",
|
312 |
+
" <th>type_COLISÃO</th>\n",
|
313 |
+
" <td>0</td>\n",
|
314 |
+
" <td>0</td>\n",
|
315 |
+
" <td>0</td>\n",
|
316 |
+
" </tr>\n",
|
317 |
+
" <tr>\n",
|
318 |
+
" <th>type_OUTROS</th>\n",
|
319 |
+
" <td>0</td>\n",
|
320 |
+
" <td>0</td>\n",
|
321 |
+
" <td>0</td>\n",
|
322 |
+
" </tr>\n",
|
323 |
+
" </tbody>\n",
|
324 |
+
"</table>\n",
|
325 |
+
"</div>"
|
326 |
+
],
|
327 |
+
"text/plain": [
|
328 |
+
" 0 1 2\n",
|
329 |
+
"latitude -30.009614 -30.0403 -30.069\n",
|
330 |
+
"longitude -51.185581 -51.1958 -51.1437\n",
|
331 |
+
"feridos True True True\n",
|
332 |
+
"feridos_gr False False False\n",
|
333 |
+
"fatais False False False\n",
|
334 |
+
"caminhao False False False\n",
|
335 |
+
"moto True True False\n",
|
336 |
+
"cars True True True\n",
|
337 |
+
"transport False False False\n",
|
338 |
+
"others False False False\n",
|
339 |
+
"holiday False True True\n",
|
340 |
+
"day_1 0 0 0\n",
|
341 |
+
"day_2 0 0 0\n",
|
342 |
+
"day_3 0 0 0\n",
|
343 |
+
"day_4 0 0 0\n",
|
344 |
+
"day_5 1 0 0\n",
|
345 |
+
"day_6 0 1 1\n",
|
346 |
+
"hour_1 0 0 0\n",
|
347 |
+
"hour_2 0 0 0\n",
|
348 |
+
"hour_3 0 0 0\n",
|
349 |
+
"hour_4 0 0 0\n",
|
350 |
+
"hour_5 0 0 0\n",
|
351 |
+
"hour_6 0 0 0\n",
|
352 |
+
"hour_7 0 0 0\n",
|
353 |
+
"hour_8 0 0 0\n",
|
354 |
+
"hour_9 0 0 0\n",
|
355 |
+
"hour_10 0 1 0\n",
|
356 |
+
"hour_11 0 0 0\n",
|
357 |
+
"hour_12 0 0 0\n",
|
358 |
+
"hour_13 0 0 0\n",
|
359 |
+
"hour_14 0 0 0\n",
|
360 |
+
"hour_15 0 0 0\n",
|
361 |
+
"hour_16 0 0 0\n",
|
362 |
+
"hour_17 0 0 0\n",
|
363 |
+
"hour_18 0 0 0\n",
|
364 |
+
"hour_19 1 0 1\n",
|
365 |
+
"hour_20 0 0 0\n",
|
366 |
+
"hour_21 0 0 0\n",
|
367 |
+
"hour_22 0 0 0\n",
|
368 |
+
"hour_23 0 0 0\n",
|
369 |
+
"type_ATROPELAMENTO 0 0 1\n",
|
370 |
+
"type_CHOQUE 0 0 0\n",
|
371 |
+
"type_COLISÃO 0 0 0\n",
|
372 |
+
"type_OUTROS 0 0 0"
|
373 |
+
]
|
374 |
+
},
|
375 |
+
"execution_count": 1,
|
376 |
+
"metadata": {},
|
377 |
+
"output_type": "execute_result"
|
378 |
+
}
|
379 |
+
],
|
380 |
+
"source": [
|
381 |
+
"import os.path as path\n",
|
382 |
+
"from pandas import read_csv\n",
|
383 |
+
"\n",
|
384 |
+
"file_csv = path.abspath(\"../\")\n",
|
385 |
+
"\n",
|
386 |
+
"file_csv = path.join(file_csv, \"data\" ,\"accidents_trans.csv\")\n",
|
387 |
+
"\n",
|
388 |
+
"accidents_trans = read_csv(file_csv)\n",
|
389 |
+
"\n",
|
390 |
+
"accidents_trans.head(3).T"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "markdown",
|
395 |
+
"metadata": {},
|
396 |
+
"source": [
|
397 |
+
"# 3. Data Preparation"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 2,
|
403 |
+
"metadata": {},
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"import joblib as jb # Use to save the model to deploy\n",
|
407 |
+
"from sklearn.preprocessing import StandardScaler\n",
|
408 |
+
"from sklearn.model_selection import train_test_split"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "code",
|
413 |
+
"execution_count": 3,
|
414 |
+
"metadata": {},
|
415 |
+
"outputs": [
|
416 |
+
{
|
417 |
+
"name": "stdout",
|
418 |
+
"output_type": "stream",
|
419 |
+
"text": [
|
420 |
+
"Our model to predict the probability of feridos will be create with 68218 rows and 41 features.\n"
|
421 |
+
]
|
422 |
+
}
|
423 |
+
],
|
424 |
+
"source": [
|
425 |
+
"outputs = [\"feridos\", \"feridos_gr\", \"fatais\"]\n",
|
426 |
+
"inputs = [col for col in accidents_trans.columns if col not in outputs]\n",
|
427 |
+
"\n",
|
428 |
+
"X = accidents_trans[inputs].copy()\n",
|
429 |
+
"Y = accidents_trans[outputs].copy()\n",
|
430 |
+
"\n",
|
431 |
+
"# Filtering data considering the output\n",
|
432 |
+
"output = \"feridos\"\n",
|
433 |
+
"\n",
|
434 |
+
"if output == \"feridos_gr\":\n",
|
435 |
+
" X = X[Y[\"feridos\"]]\n",
|
436 |
+
" Y = Y.loc[Y[\"feridos\"], \"feridos_gr\"]\n",
|
437 |
+
"elif output == \"fatais\":\n",
|
438 |
+
" X = X[Y[\"feridos_gr\"]]\n",
|
439 |
+
" Y = Y.loc[Y[\"feridos_gr\"], \"fatais\"]\n",
|
440 |
+
"else:\n",
|
441 |
+
" Y = Y[\"feridos\"]\n",
|
442 |
+
"\n",
|
443 |
+
"print(f\"Our model to predict the probability of \" \\\n",
|
444 |
+
" f\"{output} will be create with {X.shape[0]} \" \\\n",
|
445 |
+
" f\"rows and {X.shape[1]} features.\")"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"cell_type": "code",
|
450 |
+
"execution_count": 4,
|
451 |
+
"metadata": {},
|
452 |
+
"outputs": [],
|
453 |
+
"source": [
|
454 |
+
"import csv\n",
|
455 |
+
"\n",
|
456 |
+
"with open(\"model_features.csv\", 'w') as f:\n",
|
457 |
+
" writer = csv.writer(f)\n",
|
458 |
+
" writer.writerow(X.columns)"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "markdown",
|
463 |
+
"metadata": {},
|
464 |
+
"source": [
|
465 |
+
"Considering that we will use models scaling sensitive, we will need to scale our data first. Beside this, we will need to save our scaler for future use."
|
466 |
+
]
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"cell_type": "code",
|
470 |
+
"execution_count": 5,
|
471 |
+
"metadata": {},
|
472 |
+
"outputs": [
|
473 |
+
{
|
474 |
+
"data": {
|
475 |
+
"text/plain": [
|
476 |
+
"['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\scaler_feridos.pkl']"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
"execution_count": 5,
|
480 |
+
"metadata": {},
|
481 |
+
"output_type": "execute_result"
|
482 |
+
}
|
483 |
+
],
|
484 |
+
"source": [
|
485 |
+
"# Setting the random state using my luck number :-)\n",
|
486 |
+
"lucky_num = 7\n",
|
487 |
+
"\n",
|
488 |
+
"# X_train and y_train to train our model\n",
|
489 |
+
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
490 |
+
" X,\n",
|
491 |
+
" Y,\n",
|
492 |
+
" test_size=0.30,\n",
|
493 |
+
" random_state=lucky_num,\n",
|
494 |
+
" shuffle=True, # Used because our data is sort by date\n",
|
495 |
+
" stratify=Y) # Used because our data is unbalanced\n",
|
496 |
+
"\n",
|
497 |
+
"# Scaling\n",
|
498 |
+
"scaler = StandardScaler()\n",
|
499 |
+
"X_train = scaler.fit_transform(X_train)\n",
|
500 |
+
"X_test = scaler.transform(X_test)\n",
|
501 |
+
"\n",
|
502 |
+
"# Saving scaler\n",
|
503 |
+
"file_name = \"scaler_\" + output + '.pkl'\n",
|
504 |
+
"jb.dump(scaler, path.join(path.abspath(\"./\"), file_name))"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "markdown",
|
509 |
+
"metadata": {},
|
510 |
+
"source": [
|
511 |
+
"# 4. Data Modeling\n",
|
512 |
+
"\n",
|
513 |
+
"We will create and use cross-validation to evaluate the following models:\n",
|
514 |
+
"\n",
|
515 |
+
"- Logistic Regression;\n",
|
516 |
+
"\n",
|
517 |
+
"- Gaussian Naive Bayes;\n",
|
518 |
+
"\n",
|
519 |
+
"- K Neighbors;\n",
|
520 |
+
"\n",
|
521 |
+
"- Random Forest;\n",
|
522 |
+
"\n",
|
523 |
+
"- Gradient Boosting; and,\n",
|
524 |
+
"\n",
|
525 |
+
"- XGBoost.\n",
|
526 |
+
"\n",
|
527 |
+
"We will use two scores to select and evaluate our models:\n",
|
528 |
+
"\n",
|
529 |
+
"- F1 score: composition between the precision (how much our model correct classify every true label) and recall (how moch our model correct indicate true labels); and,\n",
|
530 |
+
"\n",
|
531 |
+
"- Brier score: average between the correct and the predict probability.\n",
|
532 |
+
"\n",
|
533 |
+
"However, we will see other metrics to support our decision:\n",
|
534 |
+
"\n",
|
535 |
+
"- Accurancy;\n",
|
536 |
+
"\n",
|
537 |
+
"- ROC_AOC; and,\n",
|
538 |
+
"\n",
|
539 |
+
"- Log loss (an other way to quantify the quality of probability predictions).\n",
|
540 |
+
"\n",
|
541 |
+
"And, before you go, we will find for each model if there is a hyperparameter to deal with the unbalanced output."
|
542 |
+
]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"cell_type": "code",
|
546 |
+
"execution_count": 6,
|
547 |
+
"metadata": {},
|
548 |
+
"outputs": [],
|
549 |
+
"source": [
|
550 |
+
"import pandas as pd\n",
|
551 |
+
"import xgboost as xgb\n",
|
552 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
553 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
554 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
555 |
+
"from sklearn.model_selection import cross_validate \n",
|
556 |
+
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier\n",
|
557 |
+
"from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, brier_score_loss, log_loss\n",
|
558 |
+
"\n",
|
559 |
+
"scores = [\"accuracy\", \"f1\", \"precision\", \"recall\", \"roc_auc\", \"neg_brier_score\",\"neg_log_loss\"]"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"cell_type": "code",
|
564 |
+
"execution_count": 7,
|
565 |
+
"metadata": {},
|
566 |
+
"outputs": [],
|
567 |
+
"source": [
|
568 |
+
"def eval_model(cls) -> tuple:\n",
|
569 |
+
" \"\"\"This function will calculate the metrics\n",
|
570 |
+
" to evaluate a classification model.\n",
|
571 |
+
" \"\"\"\n",
|
572 |
+
" # Predicting labels and probabilities\n",
|
573 |
+
" y_pred = cls.predict(X_test)\n",
|
574 |
+
" y_prob = cls.predict_proba(X_test)[:,1]\n",
|
575 |
+
"\n",
|
576 |
+
" # Calculating scores\n",
|
577 |
+
" accurancy = accuracy_score(y_test, y_pred)\n",
|
578 |
+
" f1 = f1_score(y_test, y_pred)\n",
|
579 |
+
" recall = recall_score(y_test, y_pred)\n",
|
580 |
+
" precision = precision_score(y_test, y_pred)\n",
|
581 |
+
" roc_auc = roc_auc_score(y_test, y_prob) # https://datascience.stackexchange.com/questions/114394/does-roc-auc-different-between-crossval-and-test-set-indicate-overfitting-or-oth\n",
|
582 |
+
" brier_score = brier_score_loss(y_test, y_prob)\n",
|
583 |
+
" log_loss_value = log_loss(y_test, y_prob)\n",
|
584 |
+
"\n",
|
585 |
+
" return accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value\n",
|
586 |
+
"\n",
|
587 |
+
"def create_model(name: str, cls) -> list:\n",
|
588 |
+
" \"\"\"This function will create some models\n",
|
589 |
+
" and return scores to evaluate it.\"\"\"\n",
|
590 |
+
" # Ftting model\n",
|
591 |
+
" cls.fit(X_train, y_train)\n",
|
592 |
+
"\n",
|
593 |
+
" # Using cross-validation to evaluate the model fitted\n",
|
594 |
+
" cls_cross = cross_validate(\n",
|
595 |
+
" estimator=cls,\n",
|
596 |
+
" X=X_train,\n",
|
597 |
+
" y=y_train,\n",
|
598 |
+
" cv=5,\n",
|
599 |
+
" scoring=scores)\n",
|
600 |
+
"\n",
|
601 |
+
" df_cv = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
|
602 |
+
"\n",
|
603 |
+
" # Calculating score to test set\n",
|
604 |
+
" accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls)\n",
|
605 |
+
"\n",
|
606 |
+
" # Filling a dataframe to better presentation\n",
|
607 |
+
" df_cv.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
|
608 |
+
" df_cv.at[\"test_f1\", \"TestSet\"] = f1\n",
|
609 |
+
" df_cv.at[\"test_recall\", \"TestSet\"] = recall\n",
|
610 |
+
" df_cv.at[\"test_precision\", \"TestSet\"] = precision\n",
|
611 |
+
" df_cv.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
|
612 |
+
" df_cv.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
|
613 |
+
" df_cv.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
|
614 |
+
"\n",
|
615 |
+
" caption = f\"{name} Validation Scores\"\n",
|
616 |
+
"\n",
|
617 |
+
" display(df_cv.style.set_caption(caption))\n",
|
618 |
+
"\n",
|
619 |
+
" return [accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value]"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": 8,
|
625 |
+
"metadata": {},
|
626 |
+
"outputs": [
|
627 |
+
{
|
628 |
+
"data": {
|
629 |
+
"text/html": [
|
630 |
+
"<style type=\"text/css\">\n",
|
631 |
+
"</style>\n",
|
632 |
+
"<table id=\"T_31097\">\n",
|
633 |
+
" <caption>LR Validation Scores</caption>\n",
|
634 |
+
" <thead>\n",
|
635 |
+
" <tr>\n",
|
636 |
+
" <th class=\"blank level0\" > </th>\n",
|
637 |
+
" <th id=\"T_31097_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
638 |
+
" <th id=\"T_31097_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
639 |
+
" <th id=\"T_31097_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
640 |
+
" <th id=\"T_31097_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
641 |
+
" <th id=\"T_31097_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
642 |
+
" <th id=\"T_31097_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
643 |
+
" </tr>\n",
|
644 |
+
" </thead>\n",
|
645 |
+
" <tbody>\n",
|
646 |
+
" <tr>\n",
|
647 |
+
" <th id=\"T_31097_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
648 |
+
" <td id=\"T_31097_row0_col0\" class=\"data row0 col0\" >0.082354</td>\n",
|
649 |
+
" <td id=\"T_31097_row0_col1\" class=\"data row0 col1\" >0.080257</td>\n",
|
650 |
+
" <td id=\"T_31097_row0_col2\" class=\"data row0 col2\" >0.089329</td>\n",
|
651 |
+
" <td id=\"T_31097_row0_col3\" class=\"data row0 col3\" >0.094720</td>\n",
|
652 |
+
" <td id=\"T_31097_row0_col4\" class=\"data row0 col4\" >0.087742</td>\n",
|
653 |
+
" <td id=\"T_31097_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
654 |
+
" </tr>\n",
|
655 |
+
" <tr>\n",
|
656 |
+
" <th id=\"T_31097_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
657 |
+
" <td id=\"T_31097_row1_col0\" class=\"data row1 col0\" >0.016066</td>\n",
|
658 |
+
" <td id=\"T_31097_row1_col1\" class=\"data row1 col1\" >0.017635</td>\n",
|
659 |
+
" <td id=\"T_31097_row1_col2\" class=\"data row1 col2\" >0.020100</td>\n",
|
660 |
+
" <td id=\"T_31097_row1_col3\" class=\"data row1 col3\" >0.018260</td>\n",
|
661 |
+
" <td id=\"T_31097_row1_col4\" class=\"data row1 col4\" >0.018356</td>\n",
|
662 |
+
" <td id=\"T_31097_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
663 |
+
" </tr>\n",
|
664 |
+
" <tr>\n",
|
665 |
+
" <th id=\"T_31097_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
666 |
+
" <td id=\"T_31097_row2_col0\" class=\"data row2 col0\" >0.869228</td>\n",
|
667 |
+
" <td id=\"T_31097_row2_col1\" class=\"data row2 col1\" >0.868391</td>\n",
|
668 |
+
" <td id=\"T_31097_row2_col2\" class=\"data row2 col2\" >0.872356</td>\n",
|
669 |
+
" <td id=\"T_31097_row2_col3\" class=\"data row2 col3\" >0.869005</td>\n",
|
670 |
+
" <td id=\"T_31097_row2_col4\" class=\"data row2 col4\" >0.867539</td>\n",
|
671 |
+
" <td id=\"T_31097_row2_col5\" class=\"data row2 col5\" >0.865924</td>\n",
|
672 |
+
" </tr>\n",
|
673 |
+
" <tr>\n",
|
674 |
+
" <th id=\"T_31097_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
675 |
+
" <td id=\"T_31097_row3_col0\" class=\"data row3 col0\" >0.817584</td>\n",
|
676 |
+
" <td id=\"T_31097_row3_col1\" class=\"data row3 col1\" >0.818116</td>\n",
|
677 |
+
" <td id=\"T_31097_row3_col2\" class=\"data row3 col2\" >0.823920</td>\n",
|
678 |
+
" <td id=\"T_31097_row3_col3\" class=\"data row3 col3\" >0.819611</td>\n",
|
679 |
+
" <td id=\"T_31097_row3_col4\" class=\"data row3 col4\" >0.817011</td>\n",
|
680 |
+
" <td id=\"T_31097_row3_col5\" class=\"data row3 col5\" >0.814469</td>\n",
|
681 |
+
" </tr>\n",
|
682 |
+
" <tr>\n",
|
683 |
+
" <th id=\"T_31097_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
684 |
+
" <td id=\"T_31097_row4_col0\" class=\"data row4 col0\" >0.854135</td>\n",
|
685 |
+
" <td id=\"T_31097_row4_col1\" class=\"data row4 col1\" >0.846154</td>\n",
|
686 |
+
" <td id=\"T_31097_row4_col2\" class=\"data row4 col2\" >0.850582</td>\n",
|
687 |
+
" <td id=\"T_31097_row4_col3\" class=\"data row4 col3\" >0.844326</td>\n",
|
688 |
+
" <td id=\"T_31097_row4_col4\" class=\"data row4 col4\" >0.844498</td>\n",
|
689 |
+
" <td id=\"T_31097_row4_col5\" class=\"data row4 col5\" >0.843439</td>\n",
|
690 |
+
" </tr>\n",
|
691 |
+
" <tr>\n",
|
692 |
+
" <th id=\"T_31097_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
693 |
+
" <td id=\"T_31097_row5_col0\" class=\"data row5 col0\" >0.784034</td>\n",
|
694 |
+
" <td id=\"T_31097_row5_col1\" class=\"data row5 col1\" >0.791877</td>\n",
|
695 |
+
" <td id=\"T_31097_row5_col2\" class=\"data row5 col2\" >0.798880</td>\n",
|
696 |
+
" <td id=\"T_31097_row5_col3\" class=\"data row5 col3\" >0.796301</td>\n",
|
697 |
+
" <td id=\"T_31097_row5_col4\" class=\"data row5 col4\" >0.791258</td>\n",
|
698 |
+
" <td id=\"T_31097_row5_col5\" class=\"data row5 col5\" >0.787423</td>\n",
|
699 |
+
" </tr>\n",
|
700 |
+
" <tr>\n",
|
701 |
+
" <th id=\"T_31097_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
702 |
+
" <td id=\"T_31097_row6_col0\" class=\"data row6 col0\" >0.903418</td>\n",
|
703 |
+
" <td id=\"T_31097_row6_col1\" class=\"data row6 col1\" >0.904970</td>\n",
|
704 |
+
" <td id=\"T_31097_row6_col2\" class=\"data row6 col2\" >0.906377</td>\n",
|
705 |
+
" <td id=\"T_31097_row6_col3\" class=\"data row6 col3\" >0.902405</td>\n",
|
706 |
+
" <td id=\"T_31097_row6_col4\" class=\"data row6 col4\" >0.906939</td>\n",
|
707 |
+
" <td id=\"T_31097_row6_col5\" class=\"data row6 col5\" >0.904458</td>\n",
|
708 |
+
" </tr>\n",
|
709 |
+
" <tr>\n",
|
710 |
+
" <th id=\"T_31097_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
711 |
+
" <td id=\"T_31097_row7_col0\" class=\"data row7 col0\" >-0.109808</td>\n",
|
712 |
+
" <td id=\"T_31097_row7_col1\" class=\"data row7 col1\" >-0.109221</td>\n",
|
713 |
+
" <td id=\"T_31097_row7_col2\" class=\"data row7 col2\" >-0.106382</td>\n",
|
714 |
+
" <td id=\"T_31097_row7_col3\" class=\"data row7 col3\" >-0.110939</td>\n",
|
715 |
+
" <td id=\"T_31097_row7_col4\" class=\"data row7 col4\" >-0.109709</td>\n",
|
716 |
+
" <td id=\"T_31097_row7_col5\" class=\"data row7 col5\" >-0.110435</td>\n",
|
717 |
+
" </tr>\n",
|
718 |
+
" <tr>\n",
|
719 |
+
" <th id=\"T_31097_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
720 |
+
" <td id=\"T_31097_row8_col0\" class=\"data row8 col0\" >-0.370200</td>\n",
|
721 |
+
" <td id=\"T_31097_row8_col1\" class=\"data row8 col1\" >-0.366684</td>\n",
|
722 |
+
" <td id=\"T_31097_row8_col2\" class=\"data row8 col2\" >-0.360534</td>\n",
|
723 |
+
" <td id=\"T_31097_row8_col3\" class=\"data row8 col3\" >-0.372374</td>\n",
|
724 |
+
" <td id=\"T_31097_row8_col4\" class=\"data row8 col4\" >-0.367532</td>\n",
|
725 |
+
" <td id=\"T_31097_row8_col5\" class=\"data row8 col5\" >-0.370350</td>\n",
|
726 |
+
" </tr>\n",
|
727 |
+
" </tbody>\n",
|
728 |
+
"</table>\n"
|
729 |
+
],
|
730 |
+
"text/plain": [
|
731 |
+
"<pandas.io.formats.style.Styler at 0x1a1a6427220>"
|
732 |
+
]
|
733 |
+
},
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "display_data"
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"data": {
|
739 |
+
"text/html": [
|
740 |
+
"<style type=\"text/css\">\n",
|
741 |
+
"</style>\n",
|
742 |
+
"<table id=\"T_750d2\">\n",
|
743 |
+
" <caption>NB Validation Scores</caption>\n",
|
744 |
+
" <thead>\n",
|
745 |
+
" <tr>\n",
|
746 |
+
" <th class=\"blank level0\" > </th>\n",
|
747 |
+
" <th id=\"T_750d2_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
748 |
+
" <th id=\"T_750d2_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
749 |
+
" <th id=\"T_750d2_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
750 |
+
" <th id=\"T_750d2_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
751 |
+
" <th id=\"T_750d2_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
752 |
+
" <th id=\"T_750d2_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
753 |
+
" </tr>\n",
|
754 |
+
" </thead>\n",
|
755 |
+
" <tbody>\n",
|
756 |
+
" <tr>\n",
|
757 |
+
" <th id=\"T_750d2_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
758 |
+
" <td id=\"T_750d2_row0_col0\" class=\"data row0 col0\" >0.035410</td>\n",
|
759 |
+
" <td id=\"T_750d2_row0_col1\" class=\"data row0 col1\" >0.030015</td>\n",
|
760 |
+
" <td id=\"T_750d2_row0_col2\" class=\"data row0 col2\" >0.032639</td>\n",
|
761 |
+
" <td id=\"T_750d2_row0_col3\" class=\"data row0 col3\" >0.029752</td>\n",
|
762 |
+
" <td id=\"T_750d2_row0_col4\" class=\"data row0 col4\" >0.030653</td>\n",
|
763 |
+
" <td id=\"T_750d2_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
764 |
+
" </tr>\n",
|
765 |
+
" <tr>\n",
|
766 |
+
" <th id=\"T_750d2_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
767 |
+
" <td id=\"T_750d2_row1_col0\" class=\"data row1 col0\" >0.037826</td>\n",
|
768 |
+
" <td id=\"T_750d2_row1_col1\" class=\"data row1 col1\" >0.040993</td>\n",
|
769 |
+
" <td id=\"T_750d2_row1_col2\" class=\"data row1 col2\" >0.032376</td>\n",
|
770 |
+
" <td id=\"T_750d2_row1_col3\" class=\"data row1 col3\" >0.030767</td>\n",
|
771 |
+
" <td id=\"T_750d2_row1_col4\" class=\"data row1 col4\" >0.028092</td>\n",
|
772 |
+
" <td id=\"T_750d2_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
773 |
+
" </tr>\n",
|
774 |
+
" <tr>\n",
|
775 |
+
" <th id=\"T_750d2_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
776 |
+
" <td id=\"T_750d2_row2_col0\" class=\"data row2 col0\" >0.768401</td>\n",
|
777 |
+
" <td id=\"T_750d2_row2_col1\" class=\"data row2 col1\" >0.763376</td>\n",
|
778 |
+
" <td id=\"T_750d2_row2_col2\" class=\"data row2 col2\" >0.765131</td>\n",
|
779 |
+
" <td id=\"T_750d2_row2_col3\" class=\"data row2 col3\" >0.771518</td>\n",
|
780 |
+
" <td id=\"T_750d2_row2_col4\" class=\"data row2 col4\" >0.772251</td>\n",
|
781 |
+
" <td id=\"T_750d2_row2_col5\" class=\"data row2 col5\" >0.766637</td>\n",
|
782 |
+
" </tr>\n",
|
783 |
+
" <tr>\n",
|
784 |
+
" <th id=\"T_750d2_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
785 |
+
" <td id=\"T_750d2_row3_col0\" class=\"data row3 col0\" >0.667068</td>\n",
|
786 |
+
" <td id=\"T_750d2_row3_col1\" class=\"data row3 col1\" >0.654223</td>\n",
|
787 |
+
" <td id=\"T_750d2_row3_col2\" class=\"data row3 col2\" >0.660922</td>\n",
|
788 |
+
" <td id=\"T_750d2_row3_col3\" class=\"data row3 col3\" >0.675876</td>\n",
|
789 |
+
" <td id=\"T_750d2_row3_col4\" class=\"data row3 col4\" >0.668899</td>\n",
|
790 |
+
" <td id=\"T_750d2_row3_col5\" class=\"data row3 col5\" >0.664795</td>\n",
|
791 |
+
" </tr>\n",
|
792 |
+
" <tr>\n",
|
793 |
+
" <th id=\"T_750d2_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
794 |
+
" <td id=\"T_750d2_row4_col0\" class=\"data row4 col0\" >0.720885</td>\n",
|
795 |
+
" <td id=\"T_750d2_row4_col1\" class=\"data row4 col1\" >0.720836</td>\n",
|
796 |
+
" <td id=\"T_750d2_row4_col2\" class=\"data row4 col2\" >0.717898</td>\n",
|
797 |
+
" <td id=\"T_750d2_row4_col3\" class=\"data row4 col3\" >0.719254</td>\n",
|
798 |
+
" <td id=\"T_750d2_row4_col4\" class=\"data row4 col4\" >0.732333</td>\n",
|
799 |
+
" <td id=\"T_750d2_row4_col5\" class=\"data row4 col5\" >0.717684</td>\n",
|
800 |
+
" </tr>\n",
|
801 |
+
" <tr>\n",
|
802 |
+
" <th id=\"T_750d2_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
803 |
+
" <td id=\"T_750d2_row5_col0\" class=\"data row5 col0\" >0.620728</td>\n",
|
804 |
+
" <td id=\"T_750d2_row5_col1\" class=\"data row5 col1\" >0.598880</td>\n",
|
805 |
+
" <td id=\"T_750d2_row5_col2\" class=\"data row5 col2\" >0.612325</td>\n",
|
806 |
+
" <td id=\"T_750d2_row5_col3\" class=\"data row5 col3\" >0.637433</td>\n",
|
807 |
+
" <td id=\"T_750d2_row5_col4\" class=\"data row5 col4\" >0.615579</td>\n",
|
808 |
+
" <td id=\"T_750d2_row5_col5\" class=\"data row5 col5\" >0.619166</td>\n",
|
809 |
+
" </tr>\n",
|
810 |
+
" <tr>\n",
|
811 |
+
" <th id=\"T_750d2_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
812 |
+
" <td id=\"T_750d2_row6_col0\" class=\"data row6 col0\" >0.852290</td>\n",
|
813 |
+
" <td id=\"T_750d2_row6_col1\" class=\"data row6 col1\" >0.847184</td>\n",
|
814 |
+
" <td id=\"T_750d2_row6_col2\" class=\"data row6 col2\" >0.843733</td>\n",
|
815 |
+
" <td id=\"T_750d2_row6_col3\" class=\"data row6 col3\" >0.851873</td>\n",
|
816 |
+
" <td id=\"T_750d2_row6_col4\" class=\"data row6 col4\" >0.856047</td>\n",
|
817 |
+
" <td id=\"T_750d2_row6_col5\" class=\"data row6 col5\" >0.848834</td>\n",
|
818 |
+
" </tr>\n",
|
819 |
+
" <tr>\n",
|
820 |
+
" <th id=\"T_750d2_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
821 |
+
" <td id=\"T_750d2_row7_col0\" class=\"data row7 col0\" >-0.206596</td>\n",
|
822 |
+
" <td id=\"T_750d2_row7_col1\" class=\"data row7 col1\" >-0.210362</td>\n",
|
823 |
+
" <td id=\"T_750d2_row7_col2\" class=\"data row7 col2\" >-0.211968</td>\n",
|
824 |
+
" <td id=\"T_750d2_row7_col3\" class=\"data row7 col3\" >-0.204214</td>\n",
|
825 |
+
" <td id=\"T_750d2_row7_col4\" class=\"data row7 col4\" >-0.202682</td>\n",
|
826 |
+
" <td id=\"T_750d2_row7_col5\" class=\"data row7 col5\" >-0.208278</td>\n",
|
827 |
+
" </tr>\n",
|
828 |
+
" <tr>\n",
|
829 |
+
" <th id=\"T_750d2_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
830 |
+
" <td id=\"T_750d2_row8_col0\" class=\"data row8 col0\" >-1.668014</td>\n",
|
831 |
+
" <td id=\"T_750d2_row8_col1\" class=\"data row8 col1\" >-1.788896</td>\n",
|
832 |
+
" <td id=\"T_750d2_row8_col2\" class=\"data row8 col2\" >-1.917438</td>\n",
|
833 |
+
" <td id=\"T_750d2_row8_col3\" class=\"data row8 col3\" >-1.662381</td>\n",
|
834 |
+
" <td id=\"T_750d2_row8_col4\" class=\"data row8 col4\" >-1.670358</td>\n",
|
835 |
+
" <td id=\"T_750d2_row8_col5\" class=\"data row8 col5\" >-1.761326</td>\n",
|
836 |
+
" </tr>\n",
|
837 |
+
" </tbody>\n",
|
838 |
+
"</table>\n"
|
839 |
+
],
|
840 |
+
"text/plain": [
|
841 |
+
"<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
|
842 |
+
]
|
843 |
+
},
|
844 |
+
"metadata": {},
|
845 |
+
"output_type": "display_data"
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"data": {
|
849 |
+
"text/html": [
|
850 |
+
"<style type=\"text/css\">\n",
|
851 |
+
"</style>\n",
|
852 |
+
"<table id=\"T_2d5a3\">\n",
|
853 |
+
" <caption>KNN Validation Scores</caption>\n",
|
854 |
+
" <thead>\n",
|
855 |
+
" <tr>\n",
|
856 |
+
" <th class=\"blank level0\" > </th>\n",
|
857 |
+
" <th id=\"T_2d5a3_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
858 |
+
" <th id=\"T_2d5a3_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
859 |
+
" <th id=\"T_2d5a3_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
860 |
+
" <th id=\"T_2d5a3_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
861 |
+
" <th id=\"T_2d5a3_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
862 |
+
" <th id=\"T_2d5a3_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
863 |
+
" </tr>\n",
|
864 |
+
" </thead>\n",
|
865 |
+
" <tbody>\n",
|
866 |
+
" <tr>\n",
|
867 |
+
" <th id=\"T_2d5a3_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
868 |
+
" <td id=\"T_2d5a3_row0_col0\" class=\"data row0 col0\" >0.010002</td>\n",
|
869 |
+
" <td id=\"T_2d5a3_row0_col1\" class=\"data row0 col1\" >0.011312</td>\n",
|
870 |
+
" <td id=\"T_2d5a3_row0_col2\" class=\"data row0 col2\" >0.011621</td>\n",
|
871 |
+
" <td id=\"T_2d5a3_row0_col3\" class=\"data row0 col3\" >0.013843</td>\n",
|
872 |
+
" <td id=\"T_2d5a3_row0_col4\" class=\"data row0 col4\" >0.011473</td>\n",
|
873 |
+
" <td id=\"T_2d5a3_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
874 |
+
" </tr>\n",
|
875 |
+
" <tr>\n",
|
876 |
+
" <th id=\"T_2d5a3_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
877 |
+
" <td id=\"T_2d5a3_row1_col0\" class=\"data row1 col0\" >1.660269</td>\n",
|
878 |
+
" <td id=\"T_2d5a3_row1_col1\" class=\"data row1 col1\" >1.360570</td>\n",
|
879 |
+
" <td id=\"T_2d5a3_row1_col2\" class=\"data row1 col2\" >1.651296</td>\n",
|
880 |
+
" <td id=\"T_2d5a3_row1_col3\" class=\"data row1 col3\" >1.734129</td>\n",
|
881 |
+
" <td id=\"T_2d5a3_row1_col4\" class=\"data row1 col4\" >1.823339</td>\n",
|
882 |
+
" <td id=\"T_2d5a3_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
883 |
+
" </tr>\n",
|
884 |
+
" <tr>\n",
|
885 |
+
" <th id=\"T_2d5a3_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
886 |
+
" <td id=\"T_2d5a3_row2_col0\" class=\"data row2 col0\" >0.842320</td>\n",
|
887 |
+
" <td id=\"T_2d5a3_row2_col1\" class=\"data row2 col1\" >0.848707</td>\n",
|
888 |
+
" <td id=\"T_2d5a3_row2_col2\" class=\"data row2 col2\" >0.847330</td>\n",
|
889 |
+
" <td id=\"T_2d5a3_row2_col3\" class=\"data row2 col3\" >0.842723</td>\n",
|
890 |
+
" <td id=\"T_2d5a3_row2_col4\" class=\"data row2 col4\" >0.847749</td>\n",
|
891 |
+
" <td id=\"T_2d5a3_row2_col5\" class=\"data row2 col5\" >0.843692</td>\n",
|
892 |
+
" </tr>\n",
|
893 |
+
" <tr>\n",
|
894 |
+
" <th id=\"T_2d5a3_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
895 |
+
" <td id=\"T_2d5a3_row3_col0\" class=\"data row3 col0\" >0.776492</td>\n",
|
896 |
+
" <td id=\"T_2d5a3_row3_col1\" class=\"data row3 col1\" >0.787218</td>\n",
|
897 |
+
" <td id=\"T_2d5a3_row3_col2\" class=\"data row3 col2\" >0.783551</td>\n",
|
898 |
+
" <td id=\"T_2d5a3_row3_col3\" class=\"data row3 col3\" >0.779053</td>\n",
|
899 |
+
" <td id=\"T_2d5a3_row3_col4\" class=\"data row3 col4\" >0.786365</td>\n",
|
900 |
+
" <td id=\"T_2d5a3_row3_col5\" class=\"data row3 col5\" >0.779698</td>\n",
|
901 |
+
" </tr>\n",
|
902 |
+
" <tr>\n",
|
903 |
+
" <th id=\"T_2d5a3_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
904 |
+
" <td id=\"T_2d5a3_row4_col0\" class=\"data row4 col0\" >0.825758</td>\n",
|
905 |
+
" <td id=\"T_2d5a3_row4_col1\" class=\"data row4 col1\" >0.829867</td>\n",
|
906 |
+
" <td id=\"T_2d5a3_row4_col2\" class=\"data row4 col2\" >0.833544</td>\n",
|
907 |
+
" <td id=\"T_2d5a3_row4_col3\" class=\"data row4 col3\" >0.820068</td>\n",
|
908 |
+
" <td id=\"T_2d5a3_row4_col4\" class=\"data row4 col4\" >0.826691</td>\n",
|
909 |
+
" <td id=\"T_2d5a3_row4_col5\" class=\"data row4 col5\" >0.823778</td>\n",
|
910 |
+
" </tr>\n",
|
911 |
+
" <tr>\n",
|
912 |
+
" <th id=\"T_2d5a3_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
913 |
+
" <td id=\"T_2d5a3_row5_col0\" class=\"data row5 col0\" >0.732773</td>\n",
|
914 |
+
" <td id=\"T_2d5a3_row5_col1\" class=\"data row5 col1\" >0.748739</td>\n",
|
915 |
+
" <td id=\"T_2d5a3_row5_col2\" class=\"data row5 col2\" >0.739216</td>\n",
|
916 |
+
" <td id=\"T_2d5a3_row5_col3\" class=\"data row5 col3\" >0.741945</td>\n",
|
917 |
+
" <td id=\"T_2d5a3_row5_col4\" class=\"data row5 col4\" >0.749790</td>\n",
|
918 |
+
" <td id=\"T_2d5a3_row5_col5\" class=\"data row5 col5\" >0.740097</td>\n",
|
919 |
+
" </tr>\n",
|
920 |
+
" <tr>\n",
|
921 |
+
" <th id=\"T_2d5a3_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
922 |
+
" <td id=\"T_2d5a3_row6_col0\" class=\"data row6 col0\" >0.867330</td>\n",
|
923 |
+
" <td id=\"T_2d5a3_row6_col1\" class=\"data row6 col1\" >0.869924</td>\n",
|
924 |
+
" <td id=\"T_2d5a3_row6_col2\" class=\"data row6 col2\" >0.872951</td>\n",
|
925 |
+
" <td id=\"T_2d5a3_row6_col3\" class=\"data row6 col3\" >0.866868</td>\n",
|
926 |
+
" <td id=\"T_2d5a3_row6_col4\" class=\"data row6 col4\" >0.872277</td>\n",
|
927 |
+
" <td id=\"T_2d5a3_row6_col5\" class=\"data row6 col5\" >0.872155</td>\n",
|
928 |
+
" </tr>\n",
|
929 |
+
" <tr>\n",
|
930 |
+
" <th id=\"T_2d5a3_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
931 |
+
" <td id=\"T_2d5a3_row7_col0\" class=\"data row7 col0\" >-0.130989</td>\n",
|
932 |
+
" <td id=\"T_2d5a3_row7_col1\" class=\"data row7 col1\" >-0.127425</td>\n",
|
933 |
+
" <td id=\"T_2d5a3_row7_col2\" class=\"data row7 col2\" >-0.126777</td>\n",
|
934 |
+
" <td id=\"T_2d5a3_row7_col3\" class=\"data row7 col3\" >-0.130655</td>\n",
|
935 |
+
" <td id=\"T_2d5a3_row7_col4\" class=\"data row7 col4\" >-0.127401</td>\n",
|
936 |
+
" <td id=\"T_2d5a3_row7_col5\" class=\"data row7 col5\" >-0.128215</td>\n",
|
937 |
+
" </tr>\n",
|
938 |
+
" <tr>\n",
|
939 |
+
" <th id=\"T_2d5a3_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
940 |
+
" <td id=\"T_2d5a3_row8_col0\" class=\"data row8 col0\" >-2.083997</td>\n",
|
941 |
+
" <td id=\"T_2d5a3_row8_col1\" class=\"data row8 col1\" >-1.959589</td>\n",
|
942 |
+
" <td id=\"T_2d5a3_row8_col2\" class=\"data row8 col2\" >-1.815403</td>\n",
|
943 |
+
" <td id=\"T_2d5a3_row8_col3\" class=\"data row8 col3\" >-2.007178</td>\n",
|
944 |
+
" <td id=\"T_2d5a3_row8_col4\" class=\"data row8 col4\" >-1.929602</td>\n",
|
945 |
+
" <td id=\"T_2d5a3_row8_col5\" class=\"data row8 col5\" >-1.877810</td>\n",
|
946 |
+
" </tr>\n",
|
947 |
+
" </tbody>\n",
|
948 |
+
"</table>\n"
|
949 |
+
],
|
950 |
+
"text/plain": [
|
951 |
+
"<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
|
952 |
+
]
|
953 |
+
},
|
954 |
+
"metadata": {},
|
955 |
+
"output_type": "display_data"
|
956 |
+
},
|
957 |
+
{
|
958 |
+
"data": {
|
959 |
+
"text/html": [
|
960 |
+
"<style type=\"text/css\">\n",
|
961 |
+
"</style>\n",
|
962 |
+
"<table id=\"T_08606\">\n",
|
963 |
+
" <caption>RF Validation Scores</caption>\n",
|
964 |
+
" <thead>\n",
|
965 |
+
" <tr>\n",
|
966 |
+
" <th class=\"blank level0\" > </th>\n",
|
967 |
+
" <th id=\"T_08606_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
968 |
+
" <th id=\"T_08606_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
969 |
+
" <th id=\"T_08606_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
970 |
+
" <th id=\"T_08606_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
971 |
+
" <th id=\"T_08606_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
972 |
+
" <th id=\"T_08606_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
973 |
+
" </tr>\n",
|
974 |
+
" </thead>\n",
|
975 |
+
" <tbody>\n",
|
976 |
+
" <tr>\n",
|
977 |
+
" <th id=\"T_08606_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
978 |
+
" <td id=\"T_08606_row0_col0\" class=\"data row0 col0\" >4.099665</td>\n",
|
979 |
+
" <td id=\"T_08606_row0_col1\" class=\"data row0 col1\" >4.061200</td>\n",
|
980 |
+
" <td id=\"T_08606_row0_col2\" class=\"data row0 col2\" >4.090116</td>\n",
|
981 |
+
" <td id=\"T_08606_row0_col3\" class=\"data row0 col3\" >4.055705</td>\n",
|
982 |
+
" <td id=\"T_08606_row0_col4\" class=\"data row0 col4\" >4.050387</td>\n",
|
983 |
+
" <td id=\"T_08606_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
984 |
+
" </tr>\n",
|
985 |
+
" <tr>\n",
|
986 |
+
" <th id=\"T_08606_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
987 |
+
" <td id=\"T_08606_row1_col0\" class=\"data row1 col0\" >0.390365</td>\n",
|
988 |
+
" <td id=\"T_08606_row1_col1\" class=\"data row1 col1\" >0.389244</td>\n",
|
989 |
+
" <td id=\"T_08606_row1_col2\" class=\"data row1 col2\" >0.392108</td>\n",
|
990 |
+
" <td id=\"T_08606_row1_col3\" class=\"data row1 col3\" >0.387358</td>\n",
|
991 |
+
" <td id=\"T_08606_row1_col4\" class=\"data row1 col4\" >0.400155</td>\n",
|
992 |
+
" <td id=\"T_08606_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
993 |
+
" </tr>\n",
|
994 |
+
" <tr>\n",
|
995 |
+
" <th id=\"T_08606_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
996 |
+
" <td id=\"T_08606_row2_col0\" class=\"data row2 col0\" >0.856141</td>\n",
|
997 |
+
" <td id=\"T_08606_row2_col1\" class=\"data row2 col1\" >0.859282</td>\n",
|
998 |
+
" <td id=\"T_08606_row2_col2\" class=\"data row2 col2\" >0.861571</td>\n",
|
999 |
+
" <td id=\"T_08606_row2_col3\" class=\"data row2 col3\" >0.853508</td>\n",
|
1000 |
+
" <td id=\"T_08606_row2_col4\" class=\"data row2 col4\" >0.855393</td>\n",
|
1001 |
+
" <td id=\"T_08606_row2_col5\" class=\"data row2 col5\" >0.856152</td>\n",
|
1002 |
+
" </tr>\n",
|
1003 |
+
" <tr>\n",
|
1004 |
+
" <th id=\"T_08606_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1005 |
+
" <td id=\"T_08606_row3_col0\" class=\"data row3 col0\" >0.800349</td>\n",
|
1006 |
+
" <td id=\"T_08606_row3_col1\" class=\"data row3 col1\" >0.805217</td>\n",
|
1007 |
+
" <td id=\"T_08606_row3_col2\" class=\"data row3 col2\" >0.807681</td>\n",
|
1008 |
+
" <td id=\"T_08606_row3_col3\" class=\"data row3 col3\" >0.798676</td>\n",
|
1009 |
+
" <td id=\"T_08606_row3_col4\" class=\"data row3 col4\" >0.799477</td>\n",
|
1010 |
+
" <td id=\"T_08606_row3_col5\" class=\"data row3 col5\" >0.800623</td>\n",
|
1011 |
+
" </tr>\n",
|
1012 |
+
" <tr>\n",
|
1013 |
+
" <th id=\"T_08606_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1014 |
+
" <td id=\"T_08606_row4_col0\" class=\"data row4 col0\" >0.831522</td>\n",
|
1015 |
+
" <td id=\"T_08606_row4_col1\" class=\"data row4 col1\" >0.834234</td>\n",
|
1016 |
+
" <td id=\"T_08606_row4_col2\" class=\"data row4 col2\" >0.840194</td>\n",
|
1017 |
+
" <td id=\"T_08606_row4_col3\" class=\"data row4 col3\" >0.821006</td>\n",
|
1018 |
+
" <td id=\"T_08606_row4_col4\" class=\"data row4 col4\" >0.829717</td>\n",
|
1019 |
+
" <td id=\"T_08606_row4_col5\" class=\"data row4 col5\" >0.830547</td>\n",
|
1020 |
+
" </tr>\n",
|
1021 |
+
" <tr>\n",
|
1022 |
+
" <th id=\"T_08606_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1023 |
+
" <td id=\"T_08606_row5_col0\" class=\"data row5 col0\" >0.771429</td>\n",
|
1024 |
+
" <td id=\"T_08606_row5_col1\" class=\"data row5 col1\" >0.778151</td>\n",
|
1025 |
+
" <td id=\"T_08606_row5_col2\" class=\"data row5 col2\" >0.777591</td>\n",
|
1026 |
+
" <td id=\"T_08606_row5_col3\" class=\"data row5 col3\" >0.777529</td>\n",
|
1027 |
+
" <td id=\"T_08606_row5_col4\" class=\"data row5 col4\" >0.771365</td>\n",
|
1028 |
+
" <td id=\"T_08606_row5_col5\" class=\"data row5 col5\" >0.772781</td>\n",
|
1029 |
+
" </tr>\n",
|
1030 |
+
" <tr>\n",
|
1031 |
+
" <th id=\"T_08606_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1032 |
+
" <td id=\"T_08606_row6_col0\" class=\"data row6 col0\" >0.890122</td>\n",
|
1033 |
+
" <td id=\"T_08606_row6_col1\" class=\"data row6 col1\" >0.890561</td>\n",
|
1034 |
+
" <td id=\"T_08606_row6_col2\" class=\"data row6 col2\" >0.897321</td>\n",
|
1035 |
+
" <td id=\"T_08606_row6_col3\" class=\"data row6 col3\" >0.887396</td>\n",
|
1036 |
+
" <td id=\"T_08606_row6_col4\" class=\"data row6 col4\" >0.891078</td>\n",
|
1037 |
+
" <td id=\"T_08606_row6_col5\" class=\"data row6 col5\" >0.893466</td>\n",
|
1038 |
+
" </tr>\n",
|
1039 |
+
" <tr>\n",
|
1040 |
+
" <th id=\"T_08606_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1041 |
+
" <td id=\"T_08606_row7_col0\" class=\"data row7 col0\" >-0.116884</td>\n",
|
1042 |
+
" <td id=\"T_08606_row7_col1\" class=\"data row7 col1\" >-0.114867</td>\n",
|
1043 |
+
" <td id=\"T_08606_row7_col2\" class=\"data row7 col2\" >-0.111343</td>\n",
|
1044 |
+
" <td id=\"T_08606_row7_col3\" class=\"data row7 col3\" >-0.117719</td>\n",
|
1045 |
+
" <td id=\"T_08606_row7_col4\" class=\"data row7 col4\" >-0.116295</td>\n",
|
1046 |
+
" <td id=\"T_08606_row7_col5\" class=\"data row7 col5\" >-0.115285</td>\n",
|
1047 |
+
" </tr>\n",
|
1048 |
+
" <tr>\n",
|
1049 |
+
" <th id=\"T_08606_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1050 |
+
" <td id=\"T_08606_row8_col0\" class=\"data row8 col0\" >-0.607395</td>\n",
|
1051 |
+
" <td id=\"T_08606_row8_col1\" class=\"data row8 col1\" >-0.579640</td>\n",
|
1052 |
+
" <td id=\"T_08606_row8_col2\" class=\"data row8 col2\" >-0.536542</td>\n",
|
1053 |
+
" <td id=\"T_08606_row8_col3\" class=\"data row8 col3\" >-0.614554</td>\n",
|
1054 |
+
" <td id=\"T_08606_row8_col4\" class=\"data row8 col4\" >-0.631888</td>\n",
|
1055 |
+
" <td id=\"T_08606_row8_col5\" class=\"data row8 col5\" >-0.562042</td>\n",
|
1056 |
+
" </tr>\n",
|
1057 |
+
" </tbody>\n",
|
1058 |
+
"</table>\n"
|
1059 |
+
],
|
1060 |
+
"text/plain": [
|
1061 |
+
"<pandas.io.formats.style.Styler at 0x1a1a6bb4820>"
|
1062 |
+
]
|
1063 |
+
},
|
1064 |
+
"metadata": {},
|
1065 |
+
"output_type": "display_data"
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"data": {
|
1069 |
+
"text/html": [
|
1070 |
+
"<style type=\"text/css\">\n",
|
1071 |
+
"</style>\n",
|
1072 |
+
"<table id=\"T_14134\">\n",
|
1073 |
+
" <caption>GBC Validation Scores</caption>\n",
|
1074 |
+
" <thead>\n",
|
1075 |
+
" <tr>\n",
|
1076 |
+
" <th class=\"blank level0\" > </th>\n",
|
1077 |
+
" <th id=\"T_14134_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1078 |
+
" <th id=\"T_14134_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1079 |
+
" <th id=\"T_14134_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1080 |
+
" <th id=\"T_14134_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1081 |
+
" <th id=\"T_14134_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1082 |
+
" <th id=\"T_14134_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1083 |
+
" </tr>\n",
|
1084 |
+
" </thead>\n",
|
1085 |
+
" <tbody>\n",
|
1086 |
+
" <tr>\n",
|
1087 |
+
" <th id=\"T_14134_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1088 |
+
" <td id=\"T_14134_row0_col0\" class=\"data row0 col0\" >4.591437</td>\n",
|
1089 |
+
" <td id=\"T_14134_row0_col1\" class=\"data row0 col1\" >4.437213</td>\n",
|
1090 |
+
" <td id=\"T_14134_row0_col2\" class=\"data row0 col2\" >4.121067</td>\n",
|
1091 |
+
" <td id=\"T_14134_row0_col3\" class=\"data row0 col3\" >4.142180</td>\n",
|
1092 |
+
" <td id=\"T_14134_row0_col4\" class=\"data row0 col4\" >4.113901</td>\n",
|
1093 |
+
" <td id=\"T_14134_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1094 |
+
" </tr>\n",
|
1095 |
+
" <tr>\n",
|
1096 |
+
" <th id=\"T_14134_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1097 |
+
" <td id=\"T_14134_row1_col0\" class=\"data row1 col0\" >0.055993</td>\n",
|
1098 |
+
" <td id=\"T_14134_row1_col1\" class=\"data row1 col1\" >0.048113</td>\n",
|
1099 |
+
" <td id=\"T_14134_row1_col2\" class=\"data row1 col2\" >0.049492</td>\n",
|
1100 |
+
" <td id=\"T_14134_row1_col3\" class=\"data row1 col3\" >0.050163</td>\n",
|
1101 |
+
" <td id=\"T_14134_row1_col4\" class=\"data row1 col4\" >0.055706</td>\n",
|
1102 |
+
" <td id=\"T_14134_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1103 |
+
" </tr>\n",
|
1104 |
+
" <tr>\n",
|
1105 |
+
" <th id=\"T_14134_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1106 |
+
" <td id=\"T_14134_row2_col0\" class=\"data row2 col0\" >0.871113</td>\n",
|
1107 |
+
" <td id=\"T_14134_row2_col1\" class=\"data row2 col1\" >0.873207</td>\n",
|
1108 |
+
" <td id=\"T_14134_row2_col2\" class=\"data row2 col2\" >0.878639</td>\n",
|
1109 |
+
" <td id=\"T_14134_row2_col3\" class=\"data row2 col3\" >0.870052</td>\n",
|
1110 |
+
" <td id=\"T_14134_row2_col4\" class=\"data row2 col4\" >0.870157</td>\n",
|
1111 |
+
" <td id=\"T_14134_row2_col5\" class=\"data row2 col5\" >0.871054</td>\n",
|
1112 |
+
" </tr>\n",
|
1113 |
+
" <tr>\n",
|
1114 |
+
" <th id=\"T_14134_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1115 |
+
" <td id=\"T_14134_row3_col0\" class=\"data row3 col0\" >0.817169</td>\n",
|
1116 |
+
" <td id=\"T_14134_row3_col1\" class=\"data row3 col1\" >0.820831</td>\n",
|
1117 |
+
" <td id=\"T_14134_row3_col2\" class=\"data row3 col2\" >0.827709</td>\n",
|
1118 |
+
" <td id=\"T_14134_row3_col3\" class=\"data row3 col3\" >0.817043</td>\n",
|
1119 |
+
" <td id=\"T_14134_row3_col4\" class=\"data row3 col4\" >0.817109</td>\n",
|
1120 |
+
" <td id=\"T_14134_row3_col5\" class=\"data row3 col5\" >0.817560</td>\n",
|
1121 |
+
" </tr>\n",
|
1122 |
+
" <tr>\n",
|
1123 |
+
" <th id=\"T_14134_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1124 |
+
" <td id=\"T_14134_row4_col0\" class=\"data row4 col0\" >0.869744</td>\n",
|
1125 |
+
" <td id=\"T_14134_row4_col1\" class=\"data row4 col1\" >0.869865</td>\n",
|
1126 |
+
" <td id=\"T_14134_row4_col2\" class=\"data row4 col2\" >0.881850</td>\n",
|
1127 |
+
" <td id=\"T_14134_row4_col3\" class=\"data row4 col3\" >0.862166</td>\n",
|
1128 |
+
" <td id=\"T_14134_row4_col4\" class=\"data row4 col4\" >0.862660</td>\n",
|
1129 |
+
" <td id=\"T_14134_row4_col5\" class=\"data row4 col5\" >0.867518</td>\n",
|
1130 |
+
" </tr>\n",
|
1131 |
+
" <tr>\n",
|
1132 |
+
" <th id=\"T_14134_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1133 |
+
" <td id=\"T_14134_row5_col0\" class=\"data row5 col0\" >0.770588</td>\n",
|
1134 |
+
" <td id=\"T_14134_row5_col1\" class=\"data row5 col1\" >0.777031</td>\n",
|
1135 |
+
" <td id=\"T_14134_row5_col2\" class=\"data row5 col2\" >0.779832</td>\n",
|
1136 |
+
" <td id=\"T_14134_row5_col3\" class=\"data row5 col3\" >0.776408</td>\n",
|
1137 |
+
" <td id=\"T_14134_row5_col4\" class=\"data row5 col4\" >0.776128</td>\n",
|
1138 |
+
" <td id=\"T_14134_row5_col5\" class=\"data row5 col5\" >0.773042</td>\n",
|
1139 |
+
" </tr>\n",
|
1140 |
+
" <tr>\n",
|
1141 |
+
" <th id=\"T_14134_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1142 |
+
" <td id=\"T_14134_row6_col0\" class=\"data row6 col0\" >0.907041</td>\n",
|
1143 |
+
" <td id=\"T_14134_row6_col1\" class=\"data row6 col1\" >0.908041</td>\n",
|
1144 |
+
" <td id=\"T_14134_row6_col2\" class=\"data row6 col2\" >0.911930</td>\n",
|
1145 |
+
" <td id=\"T_14134_row6_col3\" class=\"data row6 col3\" >0.906283</td>\n",
|
1146 |
+
" <td id=\"T_14134_row6_col4\" class=\"data row6 col4\" >0.909348</td>\n",
|
1147 |
+
" <td id=\"T_14134_row6_col5\" class=\"data row6 col5\" >0.908648</td>\n",
|
1148 |
+
" </tr>\n",
|
1149 |
+
" <tr>\n",
|
1150 |
+
" <th id=\"T_14134_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1151 |
+
" <td id=\"T_14134_row7_col0\" class=\"data row7 col0\" >-0.105054</td>\n",
|
1152 |
+
" <td id=\"T_14134_row7_col1\" class=\"data row7 col1\" >-0.103463</td>\n",
|
1153 |
+
" <td id=\"T_14134_row7_col2\" class=\"data row7 col2\" >-0.099338</td>\n",
|
1154 |
+
" <td id=\"T_14134_row7_col3\" class=\"data row7 col3\" >-0.104658</td>\n",
|
1155 |
+
" <td id=\"T_14134_row7_col4\" class=\"data row7 col4\" >-0.104459</td>\n",
|
1156 |
+
" <td id=\"T_14134_row7_col5\" class=\"data row7 col5\" >-0.104280</td>\n",
|
1157 |
+
" </tr>\n",
|
1158 |
+
" <tr>\n",
|
1159 |
+
" <th id=\"T_14134_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1160 |
+
" <td id=\"T_14134_row8_col0\" class=\"data row8 col0\" >-0.352792</td>\n",
|
1161 |
+
" <td id=\"T_14134_row8_col1\" class=\"data row8 col1\" >-0.348499</td>\n",
|
1162 |
+
" <td id=\"T_14134_row8_col2\" class=\"data row8 col2\" >-0.338605</td>\n",
|
1163 |
+
" <td id=\"T_14134_row8_col3\" class=\"data row8 col3\" >-0.351285</td>\n",
|
1164 |
+
" <td id=\"T_14134_row8_col4\" class=\"data row8 col4\" >-0.350193</td>\n",
|
1165 |
+
" <td id=\"T_14134_row8_col5\" class=\"data row8 col5\" >-0.350152</td>\n",
|
1166 |
+
" </tr>\n",
|
1167 |
+
" </tbody>\n",
|
1168 |
+
"</table>\n"
|
1169 |
+
],
|
1170 |
+
"text/plain": [
|
1171 |
+
"<pandas.io.formats.style.Styler at 0x1a1a86dd240>"
|
1172 |
+
]
|
1173 |
+
},
|
1174 |
+
"metadata": {},
|
1175 |
+
"output_type": "display_data"
|
1176 |
+
},
|
1177 |
+
{
|
1178 |
+
"data": {
|
1179 |
+
"text/html": [
|
1180 |
+
"<style type=\"text/css\">\n",
|
1181 |
+
"</style>\n",
|
1182 |
+
"<table id=\"T_25121\">\n",
|
1183 |
+
" <caption>XGB Validation Scores</caption>\n",
|
1184 |
+
" <thead>\n",
|
1185 |
+
" <tr>\n",
|
1186 |
+
" <th class=\"blank level0\" > </th>\n",
|
1187 |
+
" <th id=\"T_25121_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1188 |
+
" <th id=\"T_25121_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1189 |
+
" <th id=\"T_25121_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1190 |
+
" <th id=\"T_25121_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1191 |
+
" <th id=\"T_25121_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1192 |
+
" <th id=\"T_25121_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1193 |
+
" </tr>\n",
|
1194 |
+
" </thead>\n",
|
1195 |
+
" <tbody>\n",
|
1196 |
+
" <tr>\n",
|
1197 |
+
" <th id=\"T_25121_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1198 |
+
" <td id=\"T_25121_row0_col0\" class=\"data row0 col0\" >3.802029</td>\n",
|
1199 |
+
" <td id=\"T_25121_row0_col1\" class=\"data row0 col1\" >3.036764</td>\n",
|
1200 |
+
" <td id=\"T_25121_row0_col2\" class=\"data row0 col2\" >2.979647</td>\n",
|
1201 |
+
" <td id=\"T_25121_row0_col3\" class=\"data row0 col3\" >2.177232</td>\n",
|
1202 |
+
" <td id=\"T_25121_row0_col4\" class=\"data row0 col4\" >2.287098</td>\n",
|
1203 |
+
" <td id=\"T_25121_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1204 |
+
" </tr>\n",
|
1205 |
+
" <tr>\n",
|
1206 |
+
" <th id=\"T_25121_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1207 |
+
" <td id=\"T_25121_row1_col0\" class=\"data row1 col0\" >0.069013</td>\n",
|
1208 |
+
" <td id=\"T_25121_row1_col1\" class=\"data row1 col1\" >0.071819</td>\n",
|
1209 |
+
" <td id=\"T_25121_row1_col2\" class=\"data row1 col2\" >0.049402</td>\n",
|
1210 |
+
" <td id=\"T_25121_row1_col3\" class=\"data row1 col3\" >0.057279</td>\n",
|
1211 |
+
" <td id=\"T_25121_row1_col4\" class=\"data row1 col4\" >0.050020</td>\n",
|
1212 |
+
" <td id=\"T_25121_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1213 |
+
" </tr>\n",
|
1214 |
+
" <tr>\n",
|
1215 |
+
" <th id=\"T_25121_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1216 |
+
" <td id=\"T_25121_row2_col0\" class=\"data row2 col0\" >0.860224</td>\n",
|
1217 |
+
" <td id=\"T_25121_row2_col1\" class=\"data row2 col1\" >0.851848</td>\n",
|
1218 |
+
" <td id=\"T_25121_row2_col2\" class=\"data row2 col2\" >0.854136</td>\n",
|
1219 |
+
" <td id=\"T_25121_row2_col3\" class=\"data row2 col3\" >0.853298</td>\n",
|
1220 |
+
" <td id=\"T_25121_row2_col4\" class=\"data row2 col4\" >0.856021</td>\n",
|
1221 |
+
" <td id=\"T_25121_row2_col5\" class=\"data row2 col5\" >0.854344</td>\n",
|
1222 |
+
" </tr>\n",
|
1223 |
+
" <tr>\n",
|
1224 |
+
" <th id=\"T_25121_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1225 |
+
" <td id=\"T_25121_row3_col0\" class=\"data row3 col0\" >0.814145</td>\n",
|
1226 |
+
" <td id=\"T_25121_row3_col1\" class=\"data row3 col1\" >0.804747</td>\n",
|
1227 |
+
" <td id=\"T_25121_row3_col2\" class=\"data row3 col2\" >0.808259</td>\n",
|
1228 |
+
" <td id=\"T_25121_row3_col3\" class=\"data row3 col3\" >0.807370</td>\n",
|
1229 |
+
" <td id=\"T_25121_row3_col4\" class=\"data row3 col4\" >0.810371</td>\n",
|
1230 |
+
" <td id=\"T_25121_row3_col5\" class=\"data row3 col5\" >0.808283</td>\n",
|
1231 |
+
" </tr>\n",
|
1232 |
+
" <tr>\n",
|
1233 |
+
" <th id=\"T_25121_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1234 |
+
" <td id=\"T_25121_row4_col0\" class=\"data row4 col0\" >0.809300</td>\n",
|
1235 |
+
" <td id=\"T_25121_row4_col1\" class=\"data row4 col1\" >0.793038</td>\n",
|
1236 |
+
" <td id=\"T_25121_row4_col2\" class=\"data row4 col2\" >0.794587</td>\n",
|
1237 |
+
" <td id=\"T_25121_row4_col3\" class=\"data row4 col3\" >0.792657</td>\n",
|
1238 |
+
" <td id=\"T_25121_row4_col4\" class=\"data row4 col4\" >0.797936</td>\n",
|
1239 |
+
" <td id=\"T_25121_row4_col5\" class=\"data row4 col5\" >0.795443</td>\n",
|
1240 |
+
" </tr>\n",
|
1241 |
+
" <tr>\n",
|
1242 |
+
" <th id=\"T_25121_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1243 |
+
" <td id=\"T_25121_row5_col0\" class=\"data row5 col0\" >0.819048</td>\n",
|
1244 |
+
" <td id=\"T_25121_row5_col1\" class=\"data row5 col1\" >0.816807</td>\n",
|
1245 |
+
" <td id=\"T_25121_row5_col2\" class=\"data row5 col2\" >0.822409</td>\n",
|
1246 |
+
" <td id=\"T_25121_row5_col3\" class=\"data row5 col3\" >0.822639</td>\n",
|
1247 |
+
" <td id=\"T_25121_row5_col4\" class=\"data row5 col4\" >0.823200</td>\n",
|
1248 |
+
" <td id=\"T_25121_row5_col5\" class=\"data row5 col5\" >0.821545</td>\n",
|
1249 |
+
" </tr>\n",
|
1250 |
+
" <tr>\n",
|
1251 |
+
" <th id=\"T_25121_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1252 |
+
" <td id=\"T_25121_row6_col0\" class=\"data row6 col0\" >0.908407</td>\n",
|
1253 |
+
" <td id=\"T_25121_row6_col1\" class=\"data row6 col1\" >0.906379</td>\n",
|
1254 |
+
" <td id=\"T_25121_row6_col2\" class=\"data row6 col2\" >0.910833</td>\n",
|
1255 |
+
" <td id=\"T_25121_row6_col3\" class=\"data row6 col3\" >0.907507</td>\n",
|
1256 |
+
" <td id=\"T_25121_row6_col4\" class=\"data row6 col4\" >0.908959</td>\n",
|
1257 |
+
" <td id=\"T_25121_row6_col5\" class=\"data row6 col5\" >0.908681</td>\n",
|
1258 |
+
" </tr>\n",
|
1259 |
+
" <tr>\n",
|
1260 |
+
" <th id=\"T_25121_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1261 |
+
" <td id=\"T_25121_row7_col0\" class=\"data row7 col0\" >-0.116893</td>\n",
|
1262 |
+
" <td id=\"T_25121_row7_col1\" class=\"data row7 col1\" >-0.119319</td>\n",
|
1263 |
+
" <td id=\"T_25121_row7_col2\" class=\"data row7 col2\" >-0.116034</td>\n",
|
1264 |
+
" <td id=\"T_25121_row7_col3\" class=\"data row7 col3\" >-0.119313</td>\n",
|
1265 |
+
" <td id=\"T_25121_row7_col4\" class=\"data row7 col4\" >-0.118294</td>\n",
|
1266 |
+
" <td id=\"T_25121_row7_col5\" class=\"data row7 col5\" >-0.118266</td>\n",
|
1267 |
+
" </tr>\n",
|
1268 |
+
" <tr>\n",
|
1269 |
+
" <th id=\"T_25121_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1270 |
+
" <td id=\"T_25121_row8_col0\" class=\"data row8 col0\" >-0.393473</td>\n",
|
1271 |
+
" <td id=\"T_25121_row8_col1\" class=\"data row8 col1\" >-0.395306</td>\n",
|
1272 |
+
" <td id=\"T_25121_row8_col2\" class=\"data row8 col2\" >-0.384403</td>\n",
|
1273 |
+
" <td id=\"T_25121_row8_col3\" class=\"data row8 col3\" >-0.397352</td>\n",
|
1274 |
+
" <td id=\"T_25121_row8_col4\" class=\"data row8 col4\" >-0.394224</td>\n",
|
1275 |
+
" <td id=\"T_25121_row8_col5\" class=\"data row8 col5\" >-0.392001</td>\n",
|
1276 |
+
" </tr>\n",
|
1277 |
+
" </tbody>\n",
|
1278 |
+
"</table>\n"
|
1279 |
+
],
|
1280 |
+
"text/plain": [
|
1281 |
+
"<pandas.io.formats.style.Styler at 0x1a1a8656f80>"
|
1282 |
+
]
|
1283 |
+
},
|
1284 |
+
"metadata": {},
|
1285 |
+
"output_type": "display_data"
|
1286 |
+
}
|
1287 |
+
],
|
1288 |
+
"source": [
|
1289 |
+
"# XGB hyperparameter that deals with unbalanced\n",
|
1290 |
+
"scale_pos_weight = Y.mean()**-1\n",
|
1291 |
+
"\n",
|
1292 |
+
"# Creating the model objects\n",
|
1293 |
+
"cls_lr = LogisticRegression(\n",
|
1294 |
+
" class_weight=\"balanced\", # Hyperparameter to deal with unbalanced output\n",
|
1295 |
+
" random_state=lucky_num)\n",
|
1296 |
+
"# cls_svm = SVC(random_state=lucky_num) # Remove due its resource consumption and worst results\n",
|
1297 |
+
"cls_NB = GaussianNB()\n",
|
1298 |
+
"cls_knn = KNeighborsClassifier()\n",
|
1299 |
+
"cls_rf = RandomForestClassifier(\n",
|
1300 |
+
" random_state=lucky_num,\n",
|
1301 |
+
" class_weight=\"balanced_subsample\") # Hyperparameter to deal with unbalanced output\n",
|
1302 |
+
"cls_gbc = GradientBoostingClassifier(random_state=lucky_num)\n",
|
1303 |
+
"cls_xgb = xgb.XGBClassifier(\n",
|
1304 |
+
" objective=\"binary:logistic\",\n",
|
1305 |
+
" verbose=None,\n",
|
1306 |
+
" random_state=lucky_num,\n",
|
1307 |
+
" scale_pos_weight = scale_pos_weight)\n",
|
1308 |
+
"\n",
|
1309 |
+
"# Lists to iterate on our modeling function\n",
|
1310 |
+
"cls_name = [\"LR\", \"NB\", \"KNN\", \"RF\", \"GBC\", \"XGB\"]\n",
|
1311 |
+
"cls_list = [cls_lr, cls_NB, cls_knn, cls_rf, cls_gbc, cls_xgb]\n",
|
1312 |
+
"\n",
|
1313 |
+
"mdl_summaries = []\n",
|
1314 |
+
"for name, inst in zip(cls_name, cls_list):\n",
|
1315 |
+
" mdl_list = create_model(name, inst)\n",
|
1316 |
+
" mdl_list = [name] + mdl_list\n",
|
1317 |
+
" mdl_summaries.append(mdl_list)\n",
|
1318 |
+
"\n",
|
1319 |
+
"df_mdl = pd.DataFrame(\n",
|
1320 |
+
" mdl_summaries,\n",
|
1321 |
+
" columns=[\n",
|
1322 |
+
" \"model\",\n",
|
1323 |
+
" \"test_accuracy\",\n",
|
1324 |
+
" \"test_f1\",\n",
|
1325 |
+
" \"test_precision\",\n",
|
1326 |
+
" \"test_recall\",\n",
|
1327 |
+
" \"test_roc_auc\",\n",
|
1328 |
+
" \"test_brier\",\n",
|
1329 |
+
" \"test_log_loss\"])"
|
1330 |
+
]
|
1331 |
+
},
|
1332 |
+
{
|
1333 |
+
"cell_type": "code",
|
1334 |
+
"execution_count": 9,
|
1335 |
+
"metadata": {},
|
1336 |
+
"outputs": [
|
1337 |
+
{
|
1338 |
+
"data": {
|
1339 |
+
"text/html": [
|
1340 |
+
"<style type=\"text/css\">\n",
|
1341 |
+
"</style>\n",
|
1342 |
+
"<table id=\"T_3ba63\">\n",
|
1343 |
+
" <caption>Test set validation scores</caption>\n",
|
1344 |
+
" <thead>\n",
|
1345 |
+
" <tr>\n",
|
1346 |
+
" <th class=\"blank level0\" > </th>\n",
|
1347 |
+
" <th id=\"T_3ba63_level0_col0\" class=\"col_heading level0 col0\" >model</th>\n",
|
1348 |
+
" <th id=\"T_3ba63_level0_col1\" class=\"col_heading level0 col1\" >test_accuracy</th>\n",
|
1349 |
+
" <th id=\"T_3ba63_level0_col2\" class=\"col_heading level0 col2\" >test_f1</th>\n",
|
1350 |
+
" <th id=\"T_3ba63_level0_col3\" class=\"col_heading level0 col3\" >test_precision</th>\n",
|
1351 |
+
" <th id=\"T_3ba63_level0_col4\" class=\"col_heading level0 col4\" >test_recall</th>\n",
|
1352 |
+
" <th id=\"T_3ba63_level0_col5\" class=\"col_heading level0 col5\" >test_roc_auc</th>\n",
|
1353 |
+
" <th id=\"T_3ba63_level0_col6\" class=\"col_heading level0 col6\" >test_brier</th>\n",
|
1354 |
+
" <th id=\"T_3ba63_level0_col7\" class=\"col_heading level0 col7\" >test_log_loss</th>\n",
|
1355 |
+
" </tr>\n",
|
1356 |
+
" </thead>\n",
|
1357 |
+
" <tbody>\n",
|
1358 |
+
" <tr>\n",
|
1359 |
+
" <th id=\"T_3ba63_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
|
1360 |
+
" <td id=\"T_3ba63_row0_col0\" class=\"data row0 col0\" >GBC</td>\n",
|
1361 |
+
" <td id=\"T_3ba63_row0_col1\" class=\"data row0 col1\" >0.871054</td>\n",
|
1362 |
+
" <td id=\"T_3ba63_row0_col2\" class=\"data row0 col2\" >0.817560</td>\n",
|
1363 |
+
" <td id=\"T_3ba63_row0_col3\" class=\"data row0 col3\" >0.867518</td>\n",
|
1364 |
+
" <td id=\"T_3ba63_row0_col4\" class=\"data row0 col4\" >0.773042</td>\n",
|
1365 |
+
" <td id=\"T_3ba63_row0_col5\" class=\"data row0 col5\" >0.908648</td>\n",
|
1366 |
+
" <td id=\"T_3ba63_row0_col6\" class=\"data row0 col6\" >0.104280</td>\n",
|
1367 |
+
" <td id=\"T_3ba63_row0_col7\" class=\"data row0 col7\" >0.350152</td>\n",
|
1368 |
+
" </tr>\n",
|
1369 |
+
" <tr>\n",
|
1370 |
+
" <th id=\"T_3ba63_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
|
1371 |
+
" <td id=\"T_3ba63_row1_col0\" class=\"data row1 col0\" >LR</td>\n",
|
1372 |
+
" <td id=\"T_3ba63_row1_col1\" class=\"data row1 col1\" >0.865924</td>\n",
|
1373 |
+
" <td id=\"T_3ba63_row1_col2\" class=\"data row1 col2\" >0.814469</td>\n",
|
1374 |
+
" <td id=\"T_3ba63_row1_col3\" class=\"data row1 col3\" >0.843439</td>\n",
|
1375 |
+
" <td id=\"T_3ba63_row1_col4\" class=\"data row1 col4\" >0.787423</td>\n",
|
1376 |
+
" <td id=\"T_3ba63_row1_col5\" class=\"data row1 col5\" >0.904458</td>\n",
|
1377 |
+
" <td id=\"T_3ba63_row1_col6\" class=\"data row1 col6\" >0.110435</td>\n",
|
1378 |
+
" <td id=\"T_3ba63_row1_col7\" class=\"data row1 col7\" >0.370350</td>\n",
|
1379 |
+
" </tr>\n",
|
1380 |
+
" <tr>\n",
|
1381 |
+
" <th id=\"T_3ba63_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
|
1382 |
+
" <td id=\"T_3ba63_row2_col0\" class=\"data row2 col0\" >XGB</td>\n",
|
1383 |
+
" <td id=\"T_3ba63_row2_col1\" class=\"data row2 col1\" >0.854344</td>\n",
|
1384 |
+
" <td id=\"T_3ba63_row2_col2\" class=\"data row2 col2\" >0.808283</td>\n",
|
1385 |
+
" <td id=\"T_3ba63_row2_col3\" class=\"data row2 col3\" >0.795443</td>\n",
|
1386 |
+
" <td id=\"T_3ba63_row2_col4\" class=\"data row2 col4\" >0.821545</td>\n",
|
1387 |
+
" <td id=\"T_3ba63_row2_col5\" class=\"data row2 col5\" >0.908681</td>\n",
|
1388 |
+
" <td id=\"T_3ba63_row2_col6\" class=\"data row2 col6\" >0.118266</td>\n",
|
1389 |
+
" <td id=\"T_3ba63_row2_col7\" class=\"data row2 col7\" >0.392001</td>\n",
|
1390 |
+
" </tr>\n",
|
1391 |
+
" <tr>\n",
|
1392 |
+
" <th id=\"T_3ba63_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
|
1393 |
+
" <td id=\"T_3ba63_row3_col0\" class=\"data row3 col0\" >RF</td>\n",
|
1394 |
+
" <td id=\"T_3ba63_row3_col1\" class=\"data row3 col1\" >0.856152</td>\n",
|
1395 |
+
" <td id=\"T_3ba63_row3_col2\" class=\"data row3 col2\" >0.800623</td>\n",
|
1396 |
+
" <td id=\"T_3ba63_row3_col3\" class=\"data row3 col3\" >0.830547</td>\n",
|
1397 |
+
" <td id=\"T_3ba63_row3_col4\" class=\"data row3 col4\" >0.772781</td>\n",
|
1398 |
+
" <td id=\"T_3ba63_row3_col5\" class=\"data row3 col5\" >0.893466</td>\n",
|
1399 |
+
" <td id=\"T_3ba63_row3_col6\" class=\"data row3 col6\" >0.115285</td>\n",
|
1400 |
+
" <td id=\"T_3ba63_row3_col7\" class=\"data row3 col7\" >0.562042</td>\n",
|
1401 |
+
" </tr>\n",
|
1402 |
+
" <tr>\n",
|
1403 |
+
" <th id=\"T_3ba63_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
|
1404 |
+
" <td id=\"T_3ba63_row4_col0\" class=\"data row4 col0\" >KNN</td>\n",
|
1405 |
+
" <td id=\"T_3ba63_row4_col1\" class=\"data row4 col1\" >0.843692</td>\n",
|
1406 |
+
" <td id=\"T_3ba63_row4_col2\" class=\"data row4 col2\" >0.779698</td>\n",
|
1407 |
+
" <td id=\"T_3ba63_row4_col3\" class=\"data row4 col3\" >0.823778</td>\n",
|
1408 |
+
" <td id=\"T_3ba63_row4_col4\" class=\"data row4 col4\" >0.740097</td>\n",
|
1409 |
+
" <td id=\"T_3ba63_row4_col5\" class=\"data row4 col5\" >0.872155</td>\n",
|
1410 |
+
" <td id=\"T_3ba63_row4_col6\" class=\"data row4 col6\" >0.128215</td>\n",
|
1411 |
+
" <td id=\"T_3ba63_row4_col7\" class=\"data row4 col7\" >1.877810</td>\n",
|
1412 |
+
" </tr>\n",
|
1413 |
+
" <tr>\n",
|
1414 |
+
" <th id=\"T_3ba63_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
|
1415 |
+
" <td id=\"T_3ba63_row5_col0\" class=\"data row5 col0\" >NB</td>\n",
|
1416 |
+
" <td id=\"T_3ba63_row5_col1\" class=\"data row5 col1\" >0.766637</td>\n",
|
1417 |
+
" <td id=\"T_3ba63_row5_col2\" class=\"data row5 col2\" >0.664795</td>\n",
|
1418 |
+
" <td id=\"T_3ba63_row5_col3\" class=\"data row5 col3\" >0.717684</td>\n",
|
1419 |
+
" <td id=\"T_3ba63_row5_col4\" class=\"data row5 col4\" >0.619166</td>\n",
|
1420 |
+
" <td id=\"T_3ba63_row5_col5\" class=\"data row5 col5\" >0.848834</td>\n",
|
1421 |
+
" <td id=\"T_3ba63_row5_col6\" class=\"data row5 col6\" >0.208278</td>\n",
|
1422 |
+
" <td id=\"T_3ba63_row5_col7\" class=\"data row5 col7\" >1.761326</td>\n",
|
1423 |
+
" </tr>\n",
|
1424 |
+
" </tbody>\n",
|
1425 |
+
"</table>\n"
|
1426 |
+
],
|
1427 |
+
"text/plain": [
|
1428 |
+
"<pandas.io.formats.style.Styler at 0x1a1a8656bc0>"
|
1429 |
+
]
|
1430 |
+
},
|
1431 |
+
"metadata": {},
|
1432 |
+
"output_type": "display_data"
|
1433 |
+
}
|
1434 |
+
],
|
1435 |
+
"source": [
|
1436 |
+
"df_mdl.sort_values(\n",
|
1437 |
+
" \"test_f1\",\n",
|
1438 |
+
" ascending=False,\n",
|
1439 |
+
" inplace=True,\n",
|
1440 |
+
" ignore_index=True)\n",
|
1441 |
+
"\n",
|
1442 |
+
"display(df_mdl.style.set_caption(\"Test set validation scores\"))"
|
1443 |
+
]
|
1444 |
+
},
|
1445 |
+
{
|
1446 |
+
"cell_type": "markdown",
|
1447 |
+
"metadata": {},
|
1448 |
+
"source": [
|
1449 |
+
"GBC, LR, XGB and RF preset great results! We have two ways here: hyperparameters tunning or creating a composite model. Let's begin with the composite model.\n"
|
1450 |
+
]
|
1451 |
+
},
|
1452 |
+
{
|
1453 |
+
"cell_type": "code",
|
1454 |
+
"execution_count": 10,
|
1455 |
+
"metadata": {},
|
1456 |
+
"outputs": [
|
1457 |
+
{
|
1458 |
+
"data": {
|
1459 |
+
"text/html": [
|
1460 |
+
"<style type=\"text/css\">\n",
|
1461 |
+
"</style>\n",
|
1462 |
+
"<table id=\"T_04769\">\n",
|
1463 |
+
" <caption>Test set validation scores for Composite Model</caption>\n",
|
1464 |
+
" <thead>\n",
|
1465 |
+
" <tr>\n",
|
1466 |
+
" <th class=\"blank level0\" > </th>\n",
|
1467 |
+
" <th id=\"T_04769_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1468 |
+
" <th id=\"T_04769_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1469 |
+
" <th id=\"T_04769_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1470 |
+
" <th id=\"T_04769_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1471 |
+
" <th id=\"T_04769_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1472 |
+
" <th id=\"T_04769_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1473 |
+
" </tr>\n",
|
1474 |
+
" </thead>\n",
|
1475 |
+
" <tbody>\n",
|
1476 |
+
" <tr>\n",
|
1477 |
+
" <th id=\"T_04769_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1478 |
+
" <td id=\"T_04769_row0_col0\" class=\"data row0 col0\" >10.109613</td>\n",
|
1479 |
+
" <td id=\"T_04769_row0_col1\" class=\"data row0 col1\" >11.766011</td>\n",
|
1480 |
+
" <td id=\"T_04769_row0_col2\" class=\"data row0 col2\" >11.450818</td>\n",
|
1481 |
+
" <td id=\"T_04769_row0_col3\" class=\"data row0 col3\" >11.737634</td>\n",
|
1482 |
+
" <td id=\"T_04769_row0_col4\" class=\"data row0 col4\" >12.702598</td>\n",
|
1483 |
+
" <td id=\"T_04769_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1484 |
+
" </tr>\n",
|
1485 |
+
" <tr>\n",
|
1486 |
+
" <th id=\"T_04769_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1487 |
+
" <td id=\"T_04769_row1_col0\" class=\"data row1 col0\" >0.490518</td>\n",
|
1488 |
+
" <td id=\"T_04769_row1_col1\" class=\"data row1 col1\" >0.532695</td>\n",
|
1489 |
+
" <td id=\"T_04769_row1_col2\" class=\"data row1 col2\" >0.529459</td>\n",
|
1490 |
+
" <td id=\"T_04769_row1_col3\" class=\"data row1 col3\" >0.549051</td>\n",
|
1491 |
+
" <td id=\"T_04769_row1_col4\" class=\"data row1 col4\" >0.586749</td>\n",
|
1492 |
+
" <td id=\"T_04769_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1493 |
+
" </tr>\n",
|
1494 |
+
" <tr>\n",
|
1495 |
+
" <th id=\"T_04769_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1496 |
+
" <td id=\"T_04769_row2_col0\" class=\"data row2 col0\" >0.870799</td>\n",
|
1497 |
+
" <td id=\"T_04769_row2_col1\" class=\"data row2 col1\" >0.871532</td>\n",
|
1498 |
+
" <td id=\"T_04769_row2_col2\" class=\"data row2 col2\" >0.875497</td>\n",
|
1499 |
+
" <td id=\"T_04769_row2_col3\" class=\"data row2 col3\" >0.869948</td>\n",
|
1500 |
+
" <td id=\"T_04769_row2_col4\" class=\"data row2 col4\" >0.869215</td>\n",
|
1501 |
+
" <td id=\"T_04769_row2_col5\" class=\"data row2 col5\" >0.869002</td>\n",
|
1502 |
+
" </tr>\n",
|
1503 |
+
" <tr>\n",
|
1504 |
+
" <th id=\"T_04769_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1505 |
+
" <td id=\"T_04769_row3_col0\" class=\"data row3 col0\" >0.818689</td>\n",
|
1506 |
+
" <td id=\"T_04769_row3_col1\" class=\"data row3 col1\" >0.820797</td>\n",
|
1507 |
+
" <td id=\"T_04769_row3_col2\" class=\"data row3 col2\" >0.826297</td>\n",
|
1508 |
+
" <td id=\"T_04769_row3_col3\" class=\"data row3 col3\" >0.819319</td>\n",
|
1509 |
+
" <td id=\"T_04769_row3_col4\" class=\"data row3 col4\" >0.817531</td>\n",
|
1510 |
+
" <td id=\"T_04769_row3_col5\" class=\"data row3 col5\" >0.817283</td>\n",
|
1511 |
+
" </tr>\n",
|
1512 |
+
" <tr>\n",
|
1513 |
+
" <th id=\"T_04769_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1514 |
+
" <td id=\"T_04769_row4_col0\" class=\"data row4 col0\" >0.860939</td>\n",
|
1515 |
+
" <td id=\"T_04769_row4_col1\" class=\"data row4 col1\" >0.857492</td>\n",
|
1516 |
+
" <td id=\"T_04769_row4_col2\" class=\"data row4 col2\" >0.863511</td>\n",
|
1517 |
+
" <td id=\"T_04769_row4_col3\" class=\"data row4 col3\" >0.852042</td>\n",
|
1518 |
+
" <td id=\"T_04769_row4_col4\" class=\"data row4 col4\" >0.854090</td>\n",
|
1519 |
+
" <td id=\"T_04769_row4_col5\" class=\"data row4 col5\" >0.853645</td>\n",
|
1520 |
+
" </tr>\n",
|
1521 |
+
" <tr>\n",
|
1522 |
+
" <th id=\"T_04769_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1523 |
+
" <td id=\"T_04769_row5_col0\" class=\"data row5 col0\" >0.780392</td>\n",
|
1524 |
+
" <td id=\"T_04769_row5_col1\" class=\"data row5 col1\" >0.787115</td>\n",
|
1525 |
+
" <td id=\"T_04769_row5_col2\" class=\"data row5 col2\" >0.792157</td>\n",
|
1526 |
+
" <td id=\"T_04769_row5_col3\" class=\"data row5 col3\" >0.789017</td>\n",
|
1527 |
+
" <td id=\"T_04769_row5_col4\" class=\"data row5 col4\" >0.783973</td>\n",
|
1528 |
+
" <td id=\"T_04769_row5_col5\" class=\"data row5 col5\" >0.783893</td>\n",
|
1529 |
+
" </tr>\n",
|
1530 |
+
" <tr>\n",
|
1531 |
+
" <th id=\"T_04769_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1532 |
+
" <td id=\"T_04769_row6_col0\" class=\"data row6 col0\" >0.909022</td>\n",
|
1533 |
+
" <td id=\"T_04769_row6_col1\" class=\"data row6 col1\" >0.908890</td>\n",
|
1534 |
+
" <td id=\"T_04769_row6_col2\" class=\"data row6 col2\" >0.912418</td>\n",
|
1535 |
+
" <td id=\"T_04769_row6_col3\" class=\"data row6 col3\" >0.907315</td>\n",
|
1536 |
+
" <td id=\"T_04769_row6_col4\" class=\"data row6 col4\" >0.910340</td>\n",
|
1537 |
+
" <td id=\"T_04769_row6_col5\" class=\"data row6 col5\" >0.910500</td>\n",
|
1538 |
+
" </tr>\n",
|
1539 |
+
" <tr>\n",
|
1540 |
+
" <th id=\"T_04769_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1541 |
+
" <td id=\"T_04769_row7_col0\" class=\"data row7 col0\" >-0.105818</td>\n",
|
1542 |
+
" <td id=\"T_04769_row7_col1\" class=\"data row7 col1\" >-0.105354</td>\n",
|
1543 |
+
" <td id=\"T_04769_row7_col2\" class=\"data row7 col2\" >-0.101743</td>\n",
|
1544 |
+
" <td id=\"T_04769_row7_col3\" class=\"data row7 col3\" >-0.106567</td>\n",
|
1545 |
+
" <td id=\"T_04769_row7_col4\" class=\"data row7 col4\" >-0.105957</td>\n",
|
1546 |
+
" <td id=\"T_04769_row7_col5\" class=\"data row7 col5\" >-0.105743</td>\n",
|
1547 |
+
" </tr>\n",
|
1548 |
+
" <tr>\n",
|
1549 |
+
" <th id=\"T_04769_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1550 |
+
" <td id=\"T_04769_row8_col0\" class=\"data row8 col0\" >-0.356051</td>\n",
|
1551 |
+
" <td id=\"T_04769_row8_col1\" class=\"data row8 col1\" >-0.353269</td>\n",
|
1552 |
+
" <td id=\"T_04769_row8_col2\" class=\"data row8 col2\" >-0.344184</td>\n",
|
1553 |
+
" <td id=\"T_04769_row8_col3\" class=\"data row8 col3\" >-0.357062</td>\n",
|
1554 |
+
" <td id=\"T_04769_row8_col4\" class=\"data row8 col4\" >-0.355010</td>\n",
|
1555 |
+
" <td id=\"T_04769_row8_col5\" class=\"data row8 col5\" >-0.353621</td>\n",
|
1556 |
+
" </tr>\n",
|
1557 |
+
" </tbody>\n",
|
1558 |
+
"</table>\n"
|
1559 |
+
],
|
1560 |
+
"text/plain": [
|
1561 |
+
"<pandas.io.formats.style.Styler at 0x1a1859d3520>"
|
1562 |
+
]
|
1563 |
+
},
|
1564 |
+
"metadata": {},
|
1565 |
+
"output_type": "display_data"
|
1566 |
+
}
|
1567 |
+
],
|
1568 |
+
"source": [
|
1569 |
+
"# Selecting the models\n",
|
1570 |
+
"cls_name = [\"GBC\", \"XGB\", \"LR\", \"RF\",]\n",
|
1571 |
+
"cls_list = [cls_gbc, cls_xgb, cls_lr, cls_rf]\n",
|
1572 |
+
"\n",
|
1573 |
+
"# Training the voting classifier\n",
|
1574 |
+
"cls_vot = VotingClassifier([*zip(cls_name, cls_list)], voting=\"soft\")\n",
|
1575 |
+
"cls_vot.fit(X_train, y_train)\n",
|
1576 |
+
"\n",
|
1577 |
+
"# Using cross-validation to evaluate the model fitted\n",
|
1578 |
+
"cls_cross = cross_validate(\n",
|
1579 |
+
" estimator=cls_vot,\n",
|
1580 |
+
" X=X_train,\n",
|
1581 |
+
" y=y_train,\n",
|
1582 |
+
" cv=5,\n",
|
1583 |
+
" scoring=scores)\n",
|
1584 |
+
"\n",
|
1585 |
+
"df_vot = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
|
1586 |
+
"\n",
|
1587 |
+
"# Calculating score to test set\n",
|
1588 |
+
"accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls_vot)\n",
|
1589 |
+
"\n",
|
1590 |
+
"# Filling a dataframe to better presentation\n",
|
1591 |
+
"df_vot.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
|
1592 |
+
"df_vot.at[\"test_f1\", \"TestSet\"] = f1\n",
|
1593 |
+
"df_vot.at[\"test_recall\", \"TestSet\"] = recall\n",
|
1594 |
+
"df_vot.at[\"test_precision\", \"TestSet\"] = precision\n",
|
1595 |
+
"df_vot.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
|
1596 |
+
"df_vot.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
|
1597 |
+
"df_vot.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
|
1598 |
+
"\n",
|
1599 |
+
"display(df_vot.style.set_caption(\"Test set validation scores for Composite Model\"))"
|
1600 |
+
]
|
1601 |
+
},
|
1602 |
+
{
|
1603 |
+
"cell_type": "markdown",
|
1604 |
+
"metadata": {},
|
1605 |
+
"source": [
|
1606 |
+
"The composite model does not present any evidence of overfitting. For now, we will use it on our app."
|
1607 |
+
]
|
1608 |
+
},
|
1609 |
+
{
|
1610 |
+
"cell_type": "code",
|
1611 |
+
"execution_count": 11,
|
1612 |
+
"metadata": {},
|
1613 |
+
"outputs": [
|
1614 |
+
{
|
1615 |
+
"data": {
|
1616 |
+
"text/plain": [
|
1617 |
+
"['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\model_feridos.pkl']"
|
1618 |
+
]
|
1619 |
+
},
|
1620 |
+
"execution_count": 11,
|
1621 |
+
"metadata": {},
|
1622 |
+
"output_type": "execute_result"
|
1623 |
+
}
|
1624 |
+
],
|
1625 |
+
"source": [
|
1626 |
+
"# Saving\n",
|
1627 |
+
"file_name = \"model_\" + output + '.pkl'\n",
|
1628 |
+
"jb.dump(cls_vot, path.join(path.abspath(\"./\"), file_name))"
|
1629 |
+
]
|
1630 |
+
}
|
1631 |
+
],
|
1632 |
+
"metadata": {
|
1633 |
+
"kernelspec": {
|
1634 |
+
"display_name": "Python 3.10.6 64-bit",
|
1635 |
+
"language": "python",
|
1636 |
+
"name": "python3"
|
1637 |
+
},
|
1638 |
+
"language_info": {
|
1639 |
+
"codemirror_mode": {
|
1640 |
+
"name": "ipython",
|
1641 |
+
"version": 3
|
1642 |
+
},
|
1643 |
+
"file_extension": ".py",
|
1644 |
+
"mimetype": "text/x-python",
|
1645 |
+
"name": "python",
|
1646 |
+
"nbconvert_exporter": "python",
|
1647 |
+
"pygments_lexer": "ipython3",
|
1648 |
+
"version": "3.10.6"
|
1649 |
+
},
|
1650 |
+
"orig_nbformat": 4,
|
1651 |
+
"vscode": {
|
1652 |
+
"interpreter": {
|
1653 |
+
"hash": "1372d04dbd71fdc5436c5d6e671c1b9287e750e86143c81b5a7ba0acaf653c5e"
|
1654 |
+
}
|
1655 |
+
}
|
1656 |
+
},
|
1657 |
+
"nbformat": 4,
|
1658 |
+
"nbformat_minor": 2
|
1659 |
+
}
|
model/model_feridos.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee61d7ec89de8ff0cf49904823a9a892556807449852976ede436fa059ed765a
|
3 |
+
size 248858704
|
model/model_feridos_gr.ipynb
ADDED
@@ -0,0 +1,1659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# 1. Introduction\n",
|
8 |
+
"\n",
|
9 |
+
"This notebook was written to train models on the Porto Alegre traffic accidents data after the first cleaning, processing, and transformation step, which was done in a notebook in the `data` folder. In fact, we will have 3 models:\n",
|
10 |
+
"\n",
|
11 |
+
"1. Predict the probability of injured people.\n",
|
12 |
+
"\n",
|
13 |
+
"2. Predict the probability of seriously injured people.\n",
|
14 |
+
"\n",
|
15 |
+
"3. Predict the probability of fatalities during the event or after it.\n",
|
16 |
+
"\n",
|
17 |
+
"The path to training each model will be the same: we just filter the data according to the target and analyze the results properly."
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {},
|
23 |
+
"source": [
|
24 |
+
"# 2. Data Loading"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 5,
|
30 |
+
"metadata": {},
|
31 |
+
"outputs": [
|
32 |
+
{
|
33 |
+
"data": {
|
34 |
+
"text/html": [
|
35 |
+
"<div>\n",
|
36 |
+
"<style scoped>\n",
|
37 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
38 |
+
" vertical-align: middle;\n",
|
39 |
+
" }\n",
|
40 |
+
"\n",
|
41 |
+
" .dataframe tbody tr th {\n",
|
42 |
+
" vertical-align: top;\n",
|
43 |
+
" }\n",
|
44 |
+
"\n",
|
45 |
+
" .dataframe thead th {\n",
|
46 |
+
" text-align: right;\n",
|
47 |
+
" }\n",
|
48 |
+
"</style>\n",
|
49 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
50 |
+
" <thead>\n",
|
51 |
+
" <tr style=\"text-align: right;\">\n",
|
52 |
+
" <th></th>\n",
|
53 |
+
" <th>0</th>\n",
|
54 |
+
" <th>1</th>\n",
|
55 |
+
" <th>2</th>\n",
|
56 |
+
" </tr>\n",
|
57 |
+
" </thead>\n",
|
58 |
+
" <tbody>\n",
|
59 |
+
" <tr>\n",
|
60 |
+
" <th>latitude</th>\n",
|
61 |
+
" <td>-30.009614</td>\n",
|
62 |
+
" <td>-30.0403</td>\n",
|
63 |
+
" <td>-30.069</td>\n",
|
64 |
+
" </tr>\n",
|
65 |
+
" <tr>\n",
|
66 |
+
" <th>longitude</th>\n",
|
67 |
+
" <td>-51.185581</td>\n",
|
68 |
+
" <td>-51.1958</td>\n",
|
69 |
+
" <td>-51.1437</td>\n",
|
70 |
+
" </tr>\n",
|
71 |
+
" <tr>\n",
|
72 |
+
" <th>feridos</th>\n",
|
73 |
+
" <td>True</td>\n",
|
74 |
+
" <td>True</td>\n",
|
75 |
+
" <td>True</td>\n",
|
76 |
+
" </tr>\n",
|
77 |
+
" <tr>\n",
|
78 |
+
" <th>feridos_gr</th>\n",
|
79 |
+
" <td>False</td>\n",
|
80 |
+
" <td>False</td>\n",
|
81 |
+
" <td>False</td>\n",
|
82 |
+
" </tr>\n",
|
83 |
+
" <tr>\n",
|
84 |
+
" <th>fatais</th>\n",
|
85 |
+
" <td>False</td>\n",
|
86 |
+
" <td>False</td>\n",
|
87 |
+
" <td>False</td>\n",
|
88 |
+
" </tr>\n",
|
89 |
+
" <tr>\n",
|
90 |
+
" <th>caminhao</th>\n",
|
91 |
+
" <td>False</td>\n",
|
92 |
+
" <td>False</td>\n",
|
93 |
+
" <td>False</td>\n",
|
94 |
+
" </tr>\n",
|
95 |
+
" <tr>\n",
|
96 |
+
" <th>moto</th>\n",
|
97 |
+
" <td>True</td>\n",
|
98 |
+
" <td>True</td>\n",
|
99 |
+
" <td>False</td>\n",
|
100 |
+
" </tr>\n",
|
101 |
+
" <tr>\n",
|
102 |
+
" <th>cars</th>\n",
|
103 |
+
" <td>True</td>\n",
|
104 |
+
" <td>True</td>\n",
|
105 |
+
" <td>True</td>\n",
|
106 |
+
" </tr>\n",
|
107 |
+
" <tr>\n",
|
108 |
+
" <th>transport</th>\n",
|
109 |
+
" <td>False</td>\n",
|
110 |
+
" <td>False</td>\n",
|
111 |
+
" <td>False</td>\n",
|
112 |
+
" </tr>\n",
|
113 |
+
" <tr>\n",
|
114 |
+
" <th>others</th>\n",
|
115 |
+
" <td>False</td>\n",
|
116 |
+
" <td>False</td>\n",
|
117 |
+
" <td>False</td>\n",
|
118 |
+
" </tr>\n",
|
119 |
+
" <tr>\n",
|
120 |
+
" <th>holiday</th>\n",
|
121 |
+
" <td>False</td>\n",
|
122 |
+
" <td>True</td>\n",
|
123 |
+
" <td>True</td>\n",
|
124 |
+
" </tr>\n",
|
125 |
+
" <tr>\n",
|
126 |
+
" <th>day_1</th>\n",
|
127 |
+
" <td>0</td>\n",
|
128 |
+
" <td>0</td>\n",
|
129 |
+
" <td>0</td>\n",
|
130 |
+
" </tr>\n",
|
131 |
+
" <tr>\n",
|
132 |
+
" <th>day_2</th>\n",
|
133 |
+
" <td>0</td>\n",
|
134 |
+
" <td>0</td>\n",
|
135 |
+
" <td>0</td>\n",
|
136 |
+
" </tr>\n",
|
137 |
+
" <tr>\n",
|
138 |
+
" <th>day_3</th>\n",
|
139 |
+
" <td>0</td>\n",
|
140 |
+
" <td>0</td>\n",
|
141 |
+
" <td>0</td>\n",
|
142 |
+
" </tr>\n",
|
143 |
+
" <tr>\n",
|
144 |
+
" <th>day_4</th>\n",
|
145 |
+
" <td>0</td>\n",
|
146 |
+
" <td>0</td>\n",
|
147 |
+
" <td>0</td>\n",
|
148 |
+
" </tr>\n",
|
149 |
+
" <tr>\n",
|
150 |
+
" <th>day_5</th>\n",
|
151 |
+
" <td>1</td>\n",
|
152 |
+
" <td>0</td>\n",
|
153 |
+
" <td>0</td>\n",
|
154 |
+
" </tr>\n",
|
155 |
+
" <tr>\n",
|
156 |
+
" <th>day_6</th>\n",
|
157 |
+
" <td>0</td>\n",
|
158 |
+
" <td>1</td>\n",
|
159 |
+
" <td>1</td>\n",
|
160 |
+
" </tr>\n",
|
161 |
+
" <tr>\n",
|
162 |
+
" <th>hour_1</th>\n",
|
163 |
+
" <td>0</td>\n",
|
164 |
+
" <td>0</td>\n",
|
165 |
+
" <td>0</td>\n",
|
166 |
+
" </tr>\n",
|
167 |
+
" <tr>\n",
|
168 |
+
" <th>hour_2</th>\n",
|
169 |
+
" <td>0</td>\n",
|
170 |
+
" <td>0</td>\n",
|
171 |
+
" <td>0</td>\n",
|
172 |
+
" </tr>\n",
|
173 |
+
" <tr>\n",
|
174 |
+
" <th>hour_3</th>\n",
|
175 |
+
" <td>0</td>\n",
|
176 |
+
" <td>0</td>\n",
|
177 |
+
" <td>0</td>\n",
|
178 |
+
" </tr>\n",
|
179 |
+
" <tr>\n",
|
180 |
+
" <th>hour_4</th>\n",
|
181 |
+
" <td>0</td>\n",
|
182 |
+
" <td>0</td>\n",
|
183 |
+
" <td>0</td>\n",
|
184 |
+
" </tr>\n",
|
185 |
+
" <tr>\n",
|
186 |
+
" <th>hour_5</th>\n",
|
187 |
+
" <td>0</td>\n",
|
188 |
+
" <td>0</td>\n",
|
189 |
+
" <td>0</td>\n",
|
190 |
+
" </tr>\n",
|
191 |
+
" <tr>\n",
|
192 |
+
" <th>hour_6</th>\n",
|
193 |
+
" <td>0</td>\n",
|
194 |
+
" <td>0</td>\n",
|
195 |
+
" <td>0</td>\n",
|
196 |
+
" </tr>\n",
|
197 |
+
" <tr>\n",
|
198 |
+
" <th>hour_7</th>\n",
|
199 |
+
" <td>0</td>\n",
|
200 |
+
" <td>0</td>\n",
|
201 |
+
" <td>0</td>\n",
|
202 |
+
" </tr>\n",
|
203 |
+
" <tr>\n",
|
204 |
+
" <th>hour_8</th>\n",
|
205 |
+
" <td>0</td>\n",
|
206 |
+
" <td>0</td>\n",
|
207 |
+
" <td>0</td>\n",
|
208 |
+
" </tr>\n",
|
209 |
+
" <tr>\n",
|
210 |
+
" <th>hour_9</th>\n",
|
211 |
+
" <td>0</td>\n",
|
212 |
+
" <td>0</td>\n",
|
213 |
+
" <td>0</td>\n",
|
214 |
+
" </tr>\n",
|
215 |
+
" <tr>\n",
|
216 |
+
" <th>hour_10</th>\n",
|
217 |
+
" <td>0</td>\n",
|
218 |
+
" <td>1</td>\n",
|
219 |
+
" <td>0</td>\n",
|
220 |
+
" </tr>\n",
|
221 |
+
" <tr>\n",
|
222 |
+
" <th>hour_11</th>\n",
|
223 |
+
" <td>0</td>\n",
|
224 |
+
" <td>0</td>\n",
|
225 |
+
" <td>0</td>\n",
|
226 |
+
" </tr>\n",
|
227 |
+
" <tr>\n",
|
228 |
+
" <th>hour_12</th>\n",
|
229 |
+
" <td>0</td>\n",
|
230 |
+
" <td>0</td>\n",
|
231 |
+
" <td>0</td>\n",
|
232 |
+
" </tr>\n",
|
233 |
+
" <tr>\n",
|
234 |
+
" <th>hour_13</th>\n",
|
235 |
+
" <td>0</td>\n",
|
236 |
+
" <td>0</td>\n",
|
237 |
+
" <td>0</td>\n",
|
238 |
+
" </tr>\n",
|
239 |
+
" <tr>\n",
|
240 |
+
" <th>hour_14</th>\n",
|
241 |
+
" <td>0</td>\n",
|
242 |
+
" <td>0</td>\n",
|
243 |
+
" <td>0</td>\n",
|
244 |
+
" </tr>\n",
|
245 |
+
" <tr>\n",
|
246 |
+
" <th>hour_15</th>\n",
|
247 |
+
" <td>0</td>\n",
|
248 |
+
" <td>0</td>\n",
|
249 |
+
" <td>0</td>\n",
|
250 |
+
" </tr>\n",
|
251 |
+
" <tr>\n",
|
252 |
+
" <th>hour_16</th>\n",
|
253 |
+
" <td>0</td>\n",
|
254 |
+
" <td>0</td>\n",
|
255 |
+
" <td>0</td>\n",
|
256 |
+
" </tr>\n",
|
257 |
+
" <tr>\n",
|
258 |
+
" <th>hour_17</th>\n",
|
259 |
+
" <td>0</td>\n",
|
260 |
+
" <td>0</td>\n",
|
261 |
+
" <td>0</td>\n",
|
262 |
+
" </tr>\n",
|
263 |
+
" <tr>\n",
|
264 |
+
" <th>hour_18</th>\n",
|
265 |
+
" <td>0</td>\n",
|
266 |
+
" <td>0</td>\n",
|
267 |
+
" <td>0</td>\n",
|
268 |
+
" </tr>\n",
|
269 |
+
" <tr>\n",
|
270 |
+
" <th>hour_19</th>\n",
|
271 |
+
" <td>1</td>\n",
|
272 |
+
" <td>0</td>\n",
|
273 |
+
" <td>1</td>\n",
|
274 |
+
" </tr>\n",
|
275 |
+
" <tr>\n",
|
276 |
+
" <th>hour_20</th>\n",
|
277 |
+
" <td>0</td>\n",
|
278 |
+
" <td>0</td>\n",
|
279 |
+
" <td>0</td>\n",
|
280 |
+
" </tr>\n",
|
281 |
+
" <tr>\n",
|
282 |
+
" <th>hour_21</th>\n",
|
283 |
+
" <td>0</td>\n",
|
284 |
+
" <td>0</td>\n",
|
285 |
+
" <td>0</td>\n",
|
286 |
+
" </tr>\n",
|
287 |
+
" <tr>\n",
|
288 |
+
" <th>hour_22</th>\n",
|
289 |
+
" <td>0</td>\n",
|
290 |
+
" <td>0</td>\n",
|
291 |
+
" <td>0</td>\n",
|
292 |
+
" </tr>\n",
|
293 |
+
" <tr>\n",
|
294 |
+
" <th>hour_23</th>\n",
|
295 |
+
" <td>0</td>\n",
|
296 |
+
" <td>0</td>\n",
|
297 |
+
" <td>0</td>\n",
|
298 |
+
" </tr>\n",
|
299 |
+
" <tr>\n",
|
300 |
+
" <th>type_ATROPELAMENTO</th>\n",
|
301 |
+
" <td>0</td>\n",
|
302 |
+
" <td>0</td>\n",
|
303 |
+
" <td>1</td>\n",
|
304 |
+
" </tr>\n",
|
305 |
+
" <tr>\n",
|
306 |
+
" <th>type_CHOQUE</th>\n",
|
307 |
+
" <td>0</td>\n",
|
308 |
+
" <td>0</td>\n",
|
309 |
+
" <td>0</td>\n",
|
310 |
+
" </tr>\n",
|
311 |
+
" <tr>\n",
|
312 |
+
" <th>type_COLISÃO</th>\n",
|
313 |
+
" <td>0</td>\n",
|
314 |
+
" <td>0</td>\n",
|
315 |
+
" <td>0</td>\n",
|
316 |
+
" </tr>\n",
|
317 |
+
" <tr>\n",
|
318 |
+
" <th>type_OUTROS</th>\n",
|
319 |
+
" <td>0</td>\n",
|
320 |
+
" <td>0</td>\n",
|
321 |
+
" <td>0</td>\n",
|
322 |
+
" </tr>\n",
|
323 |
+
" </tbody>\n",
|
324 |
+
"</table>\n",
|
325 |
+
"</div>"
|
326 |
+
],
|
327 |
+
"text/plain": [
|
328 |
+
" 0 1 2\n",
|
329 |
+
"latitude -30.009614 -30.0403 -30.069\n",
|
330 |
+
"longitude -51.185581 -51.1958 -51.1437\n",
|
331 |
+
"feridos True True True\n",
|
332 |
+
"feridos_gr False False False\n",
|
333 |
+
"fatais False False False\n",
|
334 |
+
"caminhao False False False\n",
|
335 |
+
"moto True True False\n",
|
336 |
+
"cars True True True\n",
|
337 |
+
"transport False False False\n",
|
338 |
+
"others False False False\n",
|
339 |
+
"holiday False True True\n",
|
340 |
+
"day_1 0 0 0\n",
|
341 |
+
"day_2 0 0 0\n",
|
342 |
+
"day_3 0 0 0\n",
|
343 |
+
"day_4 0 0 0\n",
|
344 |
+
"day_5 1 0 0\n",
|
345 |
+
"day_6 0 1 1\n",
|
346 |
+
"hour_1 0 0 0\n",
|
347 |
+
"hour_2 0 0 0\n",
|
348 |
+
"hour_3 0 0 0\n",
|
349 |
+
"hour_4 0 0 0\n",
|
350 |
+
"hour_5 0 0 0\n",
|
351 |
+
"hour_6 0 0 0\n",
|
352 |
+
"hour_7 0 0 0\n",
|
353 |
+
"hour_8 0 0 0\n",
|
354 |
+
"hour_9 0 0 0\n",
|
355 |
+
"hour_10 0 1 0\n",
|
356 |
+
"hour_11 0 0 0\n",
|
357 |
+
"hour_12 0 0 0\n",
|
358 |
+
"hour_13 0 0 0\n",
|
359 |
+
"hour_14 0 0 0\n",
|
360 |
+
"hour_15 0 0 0\n",
|
361 |
+
"hour_16 0 0 0\n",
|
362 |
+
"hour_17 0 0 0\n",
|
363 |
+
"hour_18 0 0 0\n",
|
364 |
+
"hour_19 1 0 1\n",
|
365 |
+
"hour_20 0 0 0\n",
|
366 |
+
"hour_21 0 0 0\n",
|
367 |
+
"hour_22 0 0 0\n",
|
368 |
+
"hour_23 0 0 0\n",
|
369 |
+
"type_ATROPELAMENTO 0 0 1\n",
|
370 |
+
"type_CHOQUE 0 0 0\n",
|
371 |
+
"type_COLISÃO 0 0 0\n",
|
372 |
+
"type_OUTROS 0 0 0"
|
373 |
+
]
|
374 |
+
},
|
375 |
+
"execution_count": 5,
|
376 |
+
"metadata": {},
|
377 |
+
"output_type": "execute_result"
|
378 |
+
}
|
379 |
+
],
|
380 |
+
"source": [
|
381 |
+
"import os.path as path\n",
|
382 |
+
"from pandas import read_csv\n",
|
383 |
+
"\n",
|
384 |
+
"file_csv = path.abspath(\"../\")\n",
|
385 |
+
"\n",
|
386 |
+
"file_csv = path.join(file_csv, \"data\" ,\"accidents_trans.csv\")\n",
|
387 |
+
"\n",
|
388 |
+
"accidents_trans = read_csv(file_csv)\n",
|
389 |
+
"\n",
|
390 |
+
"accidents_trans.head(3).T"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "markdown",
|
395 |
+
"metadata": {},
|
396 |
+
"source": [
|
397 |
+
"# 3. Data Preparation"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 6,
|
403 |
+
"metadata": {},
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"import joblib as jb # Use to save the model to deploy\n",
|
407 |
+
"from sklearn.preprocessing import StandardScaler\n",
|
408 |
+
"from sklearn.model_selection import train_test_split"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "code",
|
413 |
+
"execution_count": 7,
|
414 |
+
"metadata": {},
|
415 |
+
"outputs": [
|
416 |
+
{
|
417 |
+
"name": "stdout",
|
418 |
+
"output_type": "stream",
|
419 |
+
"text": [
|
420 |
+
"Our model to predict the probability of feridos_gr will be create with 25497 rows and 41 features.\n"
|
421 |
+
]
|
422 |
+
}
|
423 |
+
],
|
424 |
+
"source": [
|
425 |
+
"outputs = [\"feridos\", \"feridos_gr\", \"fatais\"]\n",
|
426 |
+
"inputs = [col for col in accidents_trans.columns if col not in outputs]\n",
|
427 |
+
"\n",
|
428 |
+
"X = accidents_trans[inputs].copy()\n",
|
429 |
+
"Y = accidents_trans[outputs].copy()\n",
|
430 |
+
"\n",
|
431 |
+
"# Filtering data considering the output\n",
|
432 |
+
"output = \"feridos_gr\"\n",
|
433 |
+
"\n",
|
434 |
+
"if output == \"feridos_gr\":\n",
|
435 |
+
" X = X[Y[\"feridos\"]]\n",
|
436 |
+
" Y = Y.loc[Y[\"feridos\"], \"feridos_gr\"]\n",
|
437 |
+
"elif output == \"fatais\":\n",
|
438 |
+
" X = X[Y[\"feridos_gr\"]]\n",
|
439 |
+
" Y = Y.loc[Y[\"feridos_gr\"], \"fatais\"]\n",
|
440 |
+
"else:\n",
|
441 |
+
" Y = Y[\"feridos\"]\n",
|
442 |
+
"\n",
|
443 |
+
"print(f\"Our model to predict the probability of \" \\\n",
|
444 |
+
" f\"{output} will be create with {X.shape[0]} \" \\\n",
|
445 |
+
" f\"rows and {X.shape[1]} features.\")"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"cell_type": "code",
|
450 |
+
"execution_count": 8,
|
451 |
+
"metadata": {},
|
452 |
+
"outputs": [],
|
453 |
+
"source": [
|
454 |
+
"import csv\n",
|
455 |
+
"\n",
|
456 |
+
"with open(\"model_features.csv\", 'w') as f:\n",
|
457 |
+
" writer = csv.writer(f)\n",
|
458 |
+
" writer.writerow(X.columns)"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "markdown",
|
463 |
+
"metadata": {},
|
464 |
+
"source": [
|
465 |
+
"Since we will use scaling-sensitive models, we need to scale our data first. Besides this, we need to save our scaler for future use."
|
466 |
+
]
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"cell_type": "code",
|
470 |
+
"execution_count": 9,
|
471 |
+
"metadata": {},
|
472 |
+
"outputs": [
|
473 |
+
{
|
474 |
+
"data": {
|
475 |
+
"text/plain": [
|
476 |
+
"['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\scaler_feridos_gr.pkl']"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
"execution_count": 9,
|
480 |
+
"metadata": {},
|
481 |
+
"output_type": "execute_result"
|
482 |
+
}
|
483 |
+
],
|
484 |
+
"source": [
|
485 |
+
"# Setting the random state using my luck number :-)\n",
|
486 |
+
"lucky_num = 7\n",
|
487 |
+
"\n",
|
488 |
+
"# X_train and y_train to train our model\n",
|
489 |
+
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
490 |
+
" X,\n",
|
491 |
+
" Y,\n",
|
492 |
+
" test_size=0.30,\n",
|
493 |
+
" random_state=lucky_num,\n",
|
494 |
+
" shuffle=True, # Used because our data is sort by date\n",
|
495 |
+
" stratify=Y) # Used because our data is unbalanced\n",
|
496 |
+
"\n",
|
497 |
+
"# Scaling\n",
|
498 |
+
"scaler = StandardScaler()\n",
|
499 |
+
"X_train = scaler.fit_transform(X_train)\n",
|
500 |
+
"X_test = scaler.transform(X_test)\n",
|
501 |
+
"\n",
|
502 |
+
"# Saving scaler\n",
|
503 |
+
"file_name = \"scaler_\" + output + '.pkl'\n",
|
504 |
+
"jb.dump(scaler, path.join(path.abspath(\"./\"), file_name))"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "markdown",
|
509 |
+
"metadata": {},
|
510 |
+
"source": [
|
511 |
+
"# 4. Data Modeling\n",
|
512 |
+
"\n",
|
513 |
+
"We will create and use cross-validation to evaluate the following models:\n",
|
514 |
+
"\n",
|
515 |
+
"- Logistic Regression;\n",
|
516 |
+
"\n",
|
517 |
+
"- Gaussian Naive Bayes;\n",
|
518 |
+
"\n",
|
519 |
+
"- K Neighbors;\n",
|
520 |
+
"\n",
|
521 |
+
"- Random Forest;\n",
|
522 |
+
"\n",
|
523 |
+
"- Gradient Boosting; and,\n",
|
524 |
+
"\n",
|
525 |
+
"- XGBoost.\n",
|
526 |
+
"\n",
|
527 |
+
"We will use two scores to select and evaluate our models:\n",
|
528 |
+
"\n",
|
529 |
+
"- F1 score: the harmonic mean of precision (the fraction of predicted positives that are truly positive) and recall (the fraction of true positives that our model correctly identifies); and,\n",
|
530 |
+
"\n",
|
531 |
+
"- Brier score: the mean squared difference between the predicted probability and the actual outcome.\n",
|
532 |
+
"\n",
|
533 |
+
"However, we will see other metrics to support our decision:\n",
|
534 |
+
"\n",
|
535 |
+
"- Accuracy;\n",
|
536 |
+
"\n",
|
537 |
+
"- ROC AUC; and,\n",
|
538 |
+
"\n",
|
539 |
+
"- Log loss (another way to quantify the quality of probability predictions).\n",
|
540 |
+
"\n",
|
541 |
+
"And, before we go on, we will check for each model whether there is a hyperparameter to deal with the imbalanced output."
|
542 |
+
]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"cell_type": "code",
|
546 |
+
"execution_count": 10,
|
547 |
+
"metadata": {},
|
548 |
+
"outputs": [],
|
549 |
+
"source": [
|
550 |
+
"import pandas as pd\n",
|
551 |
+
"import xgboost as xgb\n",
|
552 |
+
"from sklearn.naive_bayes import GaussianNB\n",
|
553 |
+
"from sklearn.neighbors import KNeighborsClassifier\n",
|
554 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
555 |
+
"from sklearn.model_selection import cross_validate \n",
|
556 |
+
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier\n",
|
557 |
+
"from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, brier_score_loss, log_loss\n",
|
558 |
+
"\n",
|
559 |
+
"scores = [\"accuracy\", \"f1\", \"precision\", \"recall\", \"roc_auc\", \"neg_brier_score\",\"neg_log_loss\"]"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"cell_type": "code",
|
564 |
+
"execution_count": 11,
|
565 |
+
"metadata": {},
|
566 |
+
"outputs": [],
|
567 |
+
"source": [
|
568 |
+
"def eval_model(cls) -> tuple:\n",
|
569 |
+
" \"\"\"This function will calculate the metrics\n",
|
570 |
+
" to evaluate a classification model.\n",
|
571 |
+
" \"\"\"\n",
|
572 |
+
" # Predicting labels and probabilities\n",
|
573 |
+
" y_pred = cls.predict(X_test)\n",
|
574 |
+
" y_prob = cls.predict_proba(X_test)[:,1]\n",
|
575 |
+
"\n",
|
576 |
+
" # Calculating scores\n",
|
577 |
+
" accurancy = accuracy_score(y_test, y_pred)\n",
|
578 |
+
" f1 = f1_score(y_test, y_pred)\n",
|
579 |
+
" recall = recall_score(y_test, y_pred)\n",
|
580 |
+
" precision = precision_score(y_test, y_pred)\n",
|
581 |
+
" roc_auc = roc_auc_score(y_test, y_prob) # https://datascience.stackexchange.com/questions/114394/does-roc-auc-different-between-crossval-and-test-set-indicate-overfitting-or-oth\n",
|
582 |
+
" brier_score = brier_score_loss(y_test, y_prob)\n",
|
583 |
+
" log_loss_value = log_loss(y_test, y_prob)\n",
|
584 |
+
"\n",
|
585 |
+
" return accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value\n",
|
586 |
+
"\n",
|
587 |
+
"def create_model(name: str, cls) -> list:\n",
|
588 |
+
" \"\"\"This function will create some models\n",
|
589 |
+
" and return scores to evaluate it.\"\"\"\n",
|
590 |
+
" # Ftting model\n",
|
591 |
+
" cls.fit(X_train, y_train)\n",
|
592 |
+
"\n",
|
593 |
+
" # Using cross-validation to evaluate the model fitted\n",
|
594 |
+
" cls_cross = cross_validate(\n",
|
595 |
+
" estimator=cls,\n",
|
596 |
+
" X=X_train,\n",
|
597 |
+
" y=y_train,\n",
|
598 |
+
" cv=5,\n",
|
599 |
+
" scoring=scores)\n",
|
600 |
+
"\n",
|
601 |
+
" df_cv = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
|
602 |
+
"\n",
|
603 |
+
" # Calculating score to test set\n",
|
604 |
+
" accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls)\n",
|
605 |
+
"\n",
|
606 |
+
" # Filling a dataframe to better presentation\n",
|
607 |
+
" df_cv.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
|
608 |
+
" df_cv.at[\"test_f1\", \"TestSet\"] = f1\n",
|
609 |
+
" df_cv.at[\"test_recall\", \"TestSet\"] = recall\n",
|
610 |
+
" df_cv.at[\"test_precision\", \"TestSet\"] = precision\n",
|
611 |
+
" df_cv.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
|
612 |
+
" df_cv.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
|
613 |
+
" df_cv.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
|
614 |
+
"\n",
|
615 |
+
" caption = f\"{name} Validation Scores\"\n",
|
616 |
+
"\n",
|
617 |
+
" display(df_cv.style.set_caption(caption))\n",
|
618 |
+
"\n",
|
619 |
+
" return [accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value]"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": 15,
|
625 |
+
"metadata": {},
|
626 |
+
"outputs": [
|
627 |
+
{
|
628 |
+
"data": {
|
629 |
+
"text/html": [
|
630 |
+
"<style type=\"text/css\">\n",
|
631 |
+
"</style>\n",
|
632 |
+
"<table id=\"T_440e0\">\n",
|
633 |
+
" <caption>LR Validation Scores</caption>\n",
|
634 |
+
" <thead>\n",
|
635 |
+
" <tr>\n",
|
636 |
+
" <th class=\"blank level0\" > </th>\n",
|
637 |
+
" <th id=\"T_440e0_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
638 |
+
" <th id=\"T_440e0_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
639 |
+
" <th id=\"T_440e0_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
640 |
+
" <th id=\"T_440e0_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
641 |
+
" <th id=\"T_440e0_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
642 |
+
" <th id=\"T_440e0_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
643 |
+
" </tr>\n",
|
644 |
+
" </thead>\n",
|
645 |
+
" <tbody>\n",
|
646 |
+
" <tr>\n",
|
647 |
+
" <th id=\"T_440e0_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
648 |
+
" <td id=\"T_440e0_row0_col0\" class=\"data row0 col0\" >0.037210</td>\n",
|
649 |
+
" <td id=\"T_440e0_row0_col1\" class=\"data row0 col1\" >0.031787</td>\n",
|
650 |
+
" <td id=\"T_440e0_row0_col2\" class=\"data row0 col2\" >0.031054</td>\n",
|
651 |
+
" <td id=\"T_440e0_row0_col3\" class=\"data row0 col3\" >0.030418</td>\n",
|
652 |
+
" <td id=\"T_440e0_row0_col4\" class=\"data row0 col4\" >0.031932</td>\n",
|
653 |
+
" <td id=\"T_440e0_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
654 |
+
" </tr>\n",
|
655 |
+
" <tr>\n",
|
656 |
+
" <th id=\"T_440e0_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
657 |
+
" <td id=\"T_440e0_row1_col0\" class=\"data row1 col0\" >0.000000</td>\n",
|
658 |
+
" <td id=\"T_440e0_row1_col1\" class=\"data row1 col1\" >0.008581</td>\n",
|
659 |
+
" <td id=\"T_440e0_row1_col2\" class=\"data row1 col2\" >0.008867</td>\n",
|
660 |
+
" <td id=\"T_440e0_row1_col3\" class=\"data row1 col3\" >0.009067</td>\n",
|
661 |
+
" <td id=\"T_440e0_row1_col4\" class=\"data row1 col4\" >0.008065</td>\n",
|
662 |
+
" <td id=\"T_440e0_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
663 |
+
" </tr>\n",
|
664 |
+
" <tr>\n",
|
665 |
+
" <th id=\"T_440e0_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
666 |
+
" <td id=\"T_440e0_row2_col0\" class=\"data row2 col0\" >0.676751</td>\n",
|
667 |
+
" <td id=\"T_440e0_row2_col1\" class=\"data row2 col1\" >0.650420</td>\n",
|
668 |
+
" <td id=\"T_440e0_row2_col2\" class=\"data row2 col2\" >0.651723</td>\n",
|
669 |
+
" <td id=\"T_440e0_row2_col3\" class=\"data row2 col3\" >0.662370</td>\n",
|
670 |
+
" <td id=\"T_440e0_row2_col4\" class=\"data row2 col4\" >0.641636</td>\n",
|
671 |
+
" <td id=\"T_440e0_row2_col5\" class=\"data row2 col5\" >0.660523</td>\n",
|
672 |
+
" </tr>\n",
|
673 |
+
" <tr>\n",
|
674 |
+
" <th id=\"T_440e0_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
675 |
+
" <td id=\"T_440e0_row3_col0\" class=\"data row3 col0\" >0.437622</td>\n",
|
676 |
+
" <td id=\"T_440e0_row3_col1\" class=\"data row3 col1\" >0.422222</td>\n",
|
677 |
+
" <td id=\"T_440e0_row3_col2\" class=\"data row3 col2\" >0.413403</td>\n",
|
678 |
+
" <td id=\"T_440e0_row3_col3\" class=\"data row3 col3\" >0.424271</td>\n",
|
679 |
+
" <td id=\"T_440e0_row3_col4\" class=\"data row3 col4\" >0.401497</td>\n",
|
680 |
+
" <td id=\"T_440e0_row3_col5\" class=\"data row3 col5\" >0.414431</td>\n",
|
681 |
+
" </tr>\n",
|
682 |
+
" <tr>\n",
|
683 |
+
" <th id=\"T_440e0_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
684 |
+
" <td id=\"T_440e0_row4_col0\" class=\"data row4 col0\" >0.352433</td>\n",
|
685 |
+
" <td id=\"T_440e0_row4_col1\" class=\"data row4 col1\" >0.329957</td>\n",
|
686 |
+
" <td id=\"T_440e0_row4_col2\" class=\"data row4 col2\" >0.326379</td>\n",
|
687 |
+
" <td id=\"T_440e0_row4_col3\" class=\"data row4 col3\" >0.337386</td>\n",
|
688 |
+
" <td id=\"T_440e0_row4_col4\" class=\"data row4 col4\" >0.315673</td>\n",
|
689 |
+
" <td id=\"T_440e0_row4_col5\" class=\"data row4 col5\" >0.331889</td>\n",
|
690 |
+
" </tr>\n",
|
691 |
+
" <tr>\n",
|
692 |
+
" <th id=\"T_440e0_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
693 |
+
" <td id=\"T_440e0_row5_col0\" class=\"data row5 col0\" >0.577121</td>\n",
|
694 |
+
" <td id=\"T_440e0_row5_col1\" class=\"data row5 col1\" >0.586118</td>\n",
|
695 |
+
" <td id=\"T_440e0_row5_col2\" class=\"data row5 col2\" >0.563707</td>\n",
|
696 |
+
" <td id=\"T_440e0_row5_col3\" class=\"data row5 col3\" >0.571429</td>\n",
|
697 |
+
" <td id=\"T_440e0_row5_col4\" class=\"data row5 col4\" >0.551414</td>\n",
|
698 |
+
" <td id=\"T_440e0_row5_col5\" class=\"data row5 col5\" >0.551621</td>\n",
|
699 |
+
" </tr>\n",
|
700 |
+
" <tr>\n",
|
701 |
+
" <th id=\"T_440e0_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
702 |
+
" <td id=\"T_440e0_row6_col0\" class=\"data row6 col0\" >0.688461</td>\n",
|
703 |
+
" <td id=\"T_440e0_row6_col1\" class=\"data row6 col1\" >0.667854</td>\n",
|
704 |
+
" <td id=\"T_440e0_row6_col2\" class=\"data row6 col2\" >0.663118</td>\n",
|
705 |
+
" <td id=\"T_440e0_row6_col3\" class=\"data row6 col3\" >0.671007</td>\n",
|
706 |
+
" <td id=\"T_440e0_row6_col4\" class=\"data row6 col4\" >0.644010</td>\n",
|
707 |
+
" <td id=\"T_440e0_row6_col5\" class=\"data row6 col5\" >0.667540</td>\n",
|
708 |
+
" </tr>\n",
|
709 |
+
" <tr>\n",
|
710 |
+
" <th id=\"T_440e0_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
711 |
+
" <td id=\"T_440e0_row7_col0\" class=\"data row7 col0\" >-0.221720</td>\n",
|
712 |
+
" <td id=\"T_440e0_row7_col1\" class=\"data row7 col1\" >-0.228903</td>\n",
|
713 |
+
" <td id=\"T_440e0_row7_col2\" class=\"data row7 col2\" >-0.228032</td>\n",
|
714 |
+
" <td id=\"T_440e0_row7_col3\" class=\"data row7 col3\" >-0.226401</td>\n",
|
715 |
+
" <td id=\"T_440e0_row7_col4\" class=\"data row7 col4\" >-0.231747</td>\n",
|
716 |
+
" <td id=\"T_440e0_row7_col5\" class=\"data row7 col5\" >-0.225709</td>\n",
|
717 |
+
" </tr>\n",
|
718 |
+
" <tr>\n",
|
719 |
+
" <th id=\"T_440e0_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
720 |
+
" <td id=\"T_440e0_row8_col0\" class=\"data row8 col0\" >-0.635799</td>\n",
|
721 |
+
" <td id=\"T_440e0_row8_col1\" class=\"data row8 col1\" >-0.651183</td>\n",
|
722 |
+
" <td id=\"T_440e0_row8_col2\" class=\"data row8 col2\" >-0.649342</td>\n",
|
723 |
+
" <td id=\"T_440e0_row8_col3\" class=\"data row8 col3\" >-0.646086</td>\n",
|
724 |
+
" <td id=\"T_440e0_row8_col4\" class=\"data row8 col4\" >-0.657217</td>\n",
|
725 |
+
" <td id=\"T_440e0_row8_col5\" class=\"data row8 col5\" >-0.644816</td>\n",
|
726 |
+
" </tr>\n",
|
727 |
+
" </tbody>\n",
|
728 |
+
"</table>\n"
|
729 |
+
],
|
730 |
+
"text/plain": [
|
731 |
+
"<pandas.io.formats.style.Styler at 0x21e08e87ca0>"
|
732 |
+
]
|
733 |
+
},
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "display_data"
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"data": {
|
739 |
+
"text/html": [
|
740 |
+
"<style type=\"text/css\">\n",
|
741 |
+
"</style>\n",
|
742 |
+
"<table id=\"T_0ee79\">\n",
|
743 |
+
" <caption>NB Validation Scores</caption>\n",
|
744 |
+
" <thead>\n",
|
745 |
+
" <tr>\n",
|
746 |
+
" <th class=\"blank level0\" > </th>\n",
|
747 |
+
" <th id=\"T_0ee79_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
748 |
+
" <th id=\"T_0ee79_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
749 |
+
" <th id=\"T_0ee79_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
750 |
+
" <th id=\"T_0ee79_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
751 |
+
" <th id=\"T_0ee79_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
752 |
+
" <th id=\"T_0ee79_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
753 |
+
" </tr>\n",
|
754 |
+
" </thead>\n",
|
755 |
+
" <tbody>\n",
|
756 |
+
" <tr>\n",
|
757 |
+
" <th id=\"T_0ee79_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
758 |
+
" <td id=\"T_0ee79_row0_col0\" class=\"data row0 col0\" >0.014398</td>\n",
|
759 |
+
" <td id=\"T_0ee79_row0_col1\" class=\"data row0 col1\" >0.009557</td>\n",
|
760 |
+
" <td id=\"T_0ee79_row0_col2\" class=\"data row0 col2\" >0.006821</td>\n",
|
761 |
+
" <td id=\"T_0ee79_row0_col3\" class=\"data row0 col3\" >0.007424</td>\n",
|
762 |
+
" <td id=\"T_0ee79_row0_col4\" class=\"data row0 col4\" >0.012280</td>\n",
|
763 |
+
" <td id=\"T_0ee79_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
764 |
+
" </tr>\n",
|
765 |
+
" <tr>\n",
|
766 |
+
" <th id=\"T_0ee79_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
767 |
+
" <td id=\"T_0ee79_row1_col0\" class=\"data row1 col0\" >0.011809</td>\n",
|
768 |
+
" <td id=\"T_0ee79_row1_col1\" class=\"data row1 col1\" >0.011134</td>\n",
|
769 |
+
" <td id=\"T_0ee79_row1_col2\" class=\"data row1 col2\" >0.017092</td>\n",
|
770 |
+
" <td id=\"T_0ee79_row1_col3\" class=\"data row1 col3\" >0.011901</td>\n",
|
771 |
+
" <td id=\"T_0ee79_row1_col4\" class=\"data row1 col4\" >0.011843</td>\n",
|
772 |
+
" <td id=\"T_0ee79_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
773 |
+
" </tr>\n",
|
774 |
+
" <tr>\n",
|
775 |
+
" <th id=\"T_0ee79_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
776 |
+
" <td id=\"T_0ee79_row2_col0\" class=\"data row2 col0\" >0.713725</td>\n",
|
777 |
+
" <td id=\"T_0ee79_row2_col1\" class=\"data row2 col1\" >0.696639</td>\n",
|
778 |
+
" <td id=\"T_0ee79_row2_col2\" class=\"data row2 col2\" >0.695433</td>\n",
|
779 |
+
" <td id=\"T_0ee79_row2_col3\" class=\"data row2 col3\" >0.695153</td>\n",
|
780 |
+
" <td id=\"T_0ee79_row2_col4\" class=\"data row2 col4\" >0.678902</td>\n",
|
781 |
+
" <td id=\"T_0ee79_row2_col5\" class=\"data row2 col5\" >0.699608</td>\n",
|
782 |
+
" </tr>\n",
|
783 |
+
" <tr>\n",
|
784 |
+
" <th id=\"T_0ee79_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
785 |
+
" <td id=\"T_0ee79_row3_col0\" class=\"data row3 col0\" >0.398115</td>\n",
|
786 |
+
" <td id=\"T_0ee79_row3_col1\" class=\"data row3 col1\" >0.385706</td>\n",
|
787 |
+
" <td id=\"T_0ee79_row3_col2\" class=\"data row3 col2\" >0.386222</td>\n",
|
788 |
+
" <td id=\"T_0ee79_row3_col3\" class=\"data row3 col3\" >0.373993</td>\n",
|
789 |
+
" <td id=\"T_0ee79_row3_col4\" class=\"data row3 col4\" >0.358343</td>\n",
|
790 |
+
" <td id=\"T_0ee79_row3_col5\" class=\"data row3 col5\" >0.381260</td>\n",
|
791 |
+
" </tr>\n",
|
792 |
+
" <tr>\n",
|
793 |
+
" <th id=\"T_0ee79_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
794 |
+
" <td id=\"T_0ee79_row4_col0\" class=\"data row4 col0\" >0.367391</td>\n",
|
795 |
+
" <td id=\"T_0ee79_row4_col1\" class=\"data row4 col1\" >0.345178</td>\n",
|
796 |
+
" <td id=\"T_0ee79_row4_col2\" class=\"data row4 col2\" >0.344064</td>\n",
|
797 |
+
" <td id=\"T_0ee79_row4_col3\" class=\"data row4 col3\" >0.338189</td>\n",
|
798 |
+
" <td id=\"T_0ee79_row4_col4\" class=\"data row4 col4\" >0.317460</td>\n",
|
799 |
+
" <td id=\"T_0ee79_row4_col5\" class=\"data row4 col5\" >0.345703</td>\n",
|
800 |
+
" </tr>\n",
|
801 |
+
" <tr>\n",
|
802 |
+
" <th id=\"T_0ee79_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
803 |
+
" <td id=\"T_0ee79_row5_col0\" class=\"data row5 col0\" >0.434447</td>\n",
|
804 |
+
" <td id=\"T_0ee79_row5_col1\" class=\"data row5 col1\" >0.437018</td>\n",
|
805 |
+
" <td id=\"T_0ee79_row5_col2\" class=\"data row5 col2\" >0.440154</td>\n",
|
806 |
+
" <td id=\"T_0ee79_row5_col3\" class=\"data row5 col3\" >0.418275</td>\n",
|
807 |
+
" <td id=\"T_0ee79_row5_col4\" class=\"data row5 col4\" >0.411311</td>\n",
|
808 |
+
" <td id=\"T_0ee79_row5_col5\" class=\"data row5 col5\" >0.424970</td>\n",
|
809 |
+
" </tr>\n",
|
810 |
+
" <tr>\n",
|
811 |
+
" <th id=\"T_0ee79_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
812 |
+
" <td id=\"T_0ee79_row6_col0\" class=\"data row6 col0\" >0.658495</td>\n",
|
813 |
+
" <td id=\"T_0ee79_row6_col1\" class=\"data row6 col1\" >0.637091</td>\n",
|
814 |
+
" <td id=\"T_0ee79_row6_col2\" class=\"data row6 col2\" >0.633805</td>\n",
|
815 |
+
" <td id=\"T_0ee79_row6_col3\" class=\"data row6 col3\" >0.633238</td>\n",
|
816 |
+
" <td id=\"T_0ee79_row6_col4\" class=\"data row6 col4\" >0.609990</td>\n",
|
817 |
+
" <td id=\"T_0ee79_row6_col5\" class=\"data row6 col5\" >0.634031</td>\n",
|
818 |
+
" </tr>\n",
|
819 |
+
" <tr>\n",
|
820 |
+
" <th id=\"T_0ee79_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
821 |
+
" <td id=\"T_0ee79_row7_col0\" class=\"data row7 col0\" >-0.251232</td>\n",
|
822 |
+
" <td id=\"T_0ee79_row7_col1\" class=\"data row7 col1\" >-0.260890</td>\n",
|
823 |
+
" <td id=\"T_0ee79_row7_col2\" class=\"data row7 col2\" >-0.271004</td>\n",
|
824 |
+
" <td id=\"T_0ee79_row7_col3\" class=\"data row7 col3\" >-0.273609</td>\n",
|
825 |
+
" <td id=\"T_0ee79_row7_col4\" class=\"data row7 col4\" >-0.285054</td>\n",
|
826 |
+
" <td id=\"T_0ee79_row7_col5\" class=\"data row7 col5\" >-0.268468</td>\n",
|
827 |
+
" </tr>\n",
|
828 |
+
" <tr>\n",
|
829 |
+
" <th id=\"T_0ee79_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
830 |
+
" <td id=\"T_0ee79_row8_col0\" class=\"data row8 col0\" >-1.412893</td>\n",
|
831 |
+
" <td id=\"T_0ee79_row8_col1\" class=\"data row8 col1\" >-1.627295</td>\n",
|
832 |
+
" <td id=\"T_0ee79_row8_col2\" class=\"data row8 col2\" >-1.745289</td>\n",
|
833 |
+
" <td id=\"T_0ee79_row8_col3\" class=\"data row8 col3\" >-1.752608</td>\n",
|
834 |
+
" <td id=\"T_0ee79_row8_col4\" class=\"data row8 col4\" >-1.950351</td>\n",
|
835 |
+
" <td id=\"T_0ee79_row8_col5\" class=\"data row8 col5\" >-1.659029</td>\n",
|
836 |
+
" </tr>\n",
|
837 |
+
" </tbody>\n",
|
838 |
+
"</table>\n"
|
839 |
+
],
|
840 |
+
"text/plain": [
|
841 |
+
"<pandas.io.formats.style.Styler at 0x21e09deee00>"
|
842 |
+
]
|
843 |
+
},
|
844 |
+
"metadata": {},
|
845 |
+
"output_type": "display_data"
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"data": {
|
849 |
+
"text/html": [
|
850 |
+
"<style type=\"text/css\">\n",
|
851 |
+
"</style>\n",
|
852 |
+
"<table id=\"T_4933a\">\n",
|
853 |
+
" <caption>KNN Validation Scores</caption>\n",
|
854 |
+
" <thead>\n",
|
855 |
+
" <tr>\n",
|
856 |
+
" <th class=\"blank level0\" > </th>\n",
|
857 |
+
" <th id=\"T_4933a_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
858 |
+
" <th id=\"T_4933a_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
859 |
+
" <th id=\"T_4933a_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
860 |
+
" <th id=\"T_4933a_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
861 |
+
" <th id=\"T_4933a_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
862 |
+
" <th id=\"T_4933a_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
863 |
+
" </tr>\n",
|
864 |
+
" </thead>\n",
|
865 |
+
" <tbody>\n",
|
866 |
+
" <tr>\n",
|
867 |
+
" <th id=\"T_4933a_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
868 |
+
" <td id=\"T_4933a_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
|
869 |
+
" <td id=\"T_4933a_row0_col1\" class=\"data row0 col1\" >0.002192</td>\n",
|
870 |
+
" <td id=\"T_4933a_row0_col2\" class=\"data row0 col2\" >0.004006</td>\n",
|
871 |
+
" <td id=\"T_4933a_row0_col3\" class=\"data row0 col3\" >0.004514</td>\n",
|
872 |
+
" <td id=\"T_4933a_row0_col4\" class=\"data row0 col4\" >0.006648</td>\n",
|
873 |
+
" <td id=\"T_4933a_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
874 |
+
" </tr>\n",
|
875 |
+
" <tr>\n",
|
876 |
+
" <th id=\"T_4933a_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
877 |
+
" <td id=\"T_4933a_row1_col0\" class=\"data row1 col0\" >0.241018</td>\n",
|
878 |
+
" <td id=\"T_4933a_row1_col1\" class=\"data row1 col1\" >0.238426</td>\n",
|
879 |
+
" <td id=\"T_4933a_row1_col2\" class=\"data row1 col2\" >0.234827</td>\n",
|
880 |
+
" <td id=\"T_4933a_row1_col3\" class=\"data row1 col3\" >0.250730</td>\n",
|
881 |
+
" <td id=\"T_4933a_row1_col4\" class=\"data row1 col4\" >0.503909</td>\n",
|
882 |
+
" <td id=\"T_4933a_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
883 |
+
" </tr>\n",
|
884 |
+
" <tr>\n",
|
885 |
+
" <th id=\"T_4933a_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
886 |
+
" <td id=\"T_4933a_row2_col0\" class=\"data row2 col0\" >0.749860</td>\n",
|
887 |
+
" <td id=\"T_4933a_row2_col1\" class=\"data row2 col1\" >0.757703</td>\n",
|
888 |
+
" <td id=\"T_4933a_row2_col2\" class=\"data row2 col2\" >0.750911</td>\n",
|
889 |
+
" <td id=\"T_4933a_row2_col3\" class=\"data row2 col3\" >0.753713</td>\n",
|
890 |
+
" <td id=\"T_4933a_row2_col4\" class=\"data row2 col4\" >0.746708</td>\n",
|
891 |
+
" <td id=\"T_4933a_row2_col5\" class=\"data row2 col5\" >0.754379</td>\n",
|
892 |
+
" </tr>\n",
|
893 |
+
" <tr>\n",
|
894 |
+
" <th id=\"T_4933a_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
895 |
+
" <td id=\"T_4933a_row3_col0\" class=\"data row3 col0\" >0.203390</td>\n",
|
896 |
+
" <td id=\"T_4933a_row3_col1\" class=\"data row3 col1\" >0.214351</td>\n",
|
897 |
+
" <td id=\"T_4933a_row3_col2\" class=\"data row3 col2\" >0.201258</td>\n",
|
898 |
+
" <td id=\"T_4933a_row3_col3\" class=\"data row3 col3\" >0.208821</td>\n",
|
899 |
+
" <td id=\"T_4933a_row3_col4\" class=\"data row3 col4\" >0.208406</td>\n",
|
900 |
+
" <td id=\"T_4933a_row3_col5\" class=\"data row3 col5\" >0.214136</td>\n",
|
901 |
+
" </tr>\n",
|
902 |
+
" <tr>\n",
|
903 |
+
" <th id=\"T_4933a_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
904 |
+
" <td id=\"T_4933a_row4_col0\" class=\"data row4 col0\" >0.332362</td>\n",
|
905 |
+
" <td id=\"T_4933a_row4_col1\" class=\"data row4 col1\" >0.365325</td>\n",
|
906 |
+
" <td id=\"T_4933a_row4_col2\" class=\"data row4 col2\" >0.333333</td>\n",
|
907 |
+
" <td id=\"T_4933a_row4_col3\" class=\"data row4 col3\" >0.347305</td>\n",
|
908 |
+
" <td id=\"T_4933a_row4_col4\" class=\"data row4 col4\" >0.326923</td>\n",
|
909 |
+
" <td id=\"T_4933a_row4_col5\" class=\"data row4 col5\" >0.353103</td>\n",
|
910 |
+
" </tr>\n",
|
911 |
+
" <tr>\n",
|
912 |
+
" <th id=\"T_4933a_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
913 |
+
" <td id=\"T_4933a_row5_col0\" class=\"data row5 col0\" >0.146530</td>\n",
|
914 |
+
" <td id=\"T_4933a_row5_col1\" class=\"data row5 col1\" >0.151671</td>\n",
|
915 |
+
" <td id=\"T_4933a_row5_col2\" class=\"data row5 col2\" >0.144144</td>\n",
|
916 |
+
" <td id=\"T_4933a_row5_col3\" class=\"data row5 col3\" >0.149292</td>\n",
|
917 |
+
" <td id=\"T_4933a_row5_col4\" class=\"data row5 col4\" >0.152956</td>\n",
|
918 |
+
" <td id=\"T_4933a_row5_col5\" class=\"data row5 col5\" >0.153661</td>\n",
|
919 |
+
" </tr>\n",
|
920 |
+
" <tr>\n",
|
921 |
+
" <th id=\"T_4933a_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
922 |
+
" <td id=\"T_4933a_row6_col0\" class=\"data row6 col0\" >0.583018</td>\n",
|
923 |
+
" <td id=\"T_4933a_row6_col1\" class=\"data row6 col1\" >0.588834</td>\n",
|
924 |
+
" <td id=\"T_4933a_row6_col2\" class=\"data row6 col2\" >0.575426</td>\n",
|
925 |
+
" <td id=\"T_4933a_row6_col3\" class=\"data row6 col3\" >0.575821</td>\n",
|
926 |
+
" <td id=\"T_4933a_row6_col4\" class=\"data row6 col4\" >0.573316</td>\n",
|
927 |
+
" <td id=\"T_4933a_row6_col5\" class=\"data row6 col5\" >0.582555</td>\n",
|
928 |
+
" </tr>\n",
|
929 |
+
" <tr>\n",
|
930 |
+
" <th id=\"T_4933a_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
931 |
+
" <td id=\"T_4933a_row7_col0\" class=\"data row7 col0\" >-0.190980</td>\n",
|
932 |
+
" <td id=\"T_4933a_row7_col1\" class=\"data row7 col1\" >-0.186599</td>\n",
|
933 |
+
" <td id=\"T_4933a_row7_col2\" class=\"data row7 col2\" >-0.191325</td>\n",
|
934 |
+
" <td id=\"T_4933a_row7_col3\" class=\"data row7 col3\" >-0.190608</td>\n",
|
935 |
+
" <td id=\"T_4933a_row7_col4\" class=\"data row7 col4\" >-0.193769</td>\n",
|
936 |
+
" <td id=\"T_4933a_row7_col5\" class=\"data row7 col5\" >-0.189077</td>\n",
|
937 |
+
" </tr>\n",
|
938 |
+
" <tr>\n",
|
939 |
+
" <th id=\"T_4933a_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
940 |
+
" <td id=\"T_4933a_row8_col0\" class=\"data row8 col0\" >-2.500209</td>\n",
|
941 |
+
" <td id=\"T_4933a_row8_col1\" class=\"data row8 col1\" >-2.097705</td>\n",
|
942 |
+
" <td id=\"T_4933a_row8_col2\" class=\"data row8 col2\" >-2.376728</td>\n",
|
943 |
+
" <td id=\"T_4933a_row8_col3\" class=\"data row8 col3\" >-2.528142</td>\n",
|
944 |
+
" <td id=\"T_4933a_row8_col4\" class=\"data row8 col4\" >-2.570247</td>\n",
|
945 |
+
" <td id=\"T_4933a_row8_col5\" class=\"data row8 col5\" >-2.428060</td>\n",
|
946 |
+
" </tr>\n",
|
947 |
+
" </tbody>\n",
|
948 |
+
"</table>\n"
|
949 |
+
],
|
950 |
+
"text/plain": [
|
951 |
+
"<pandas.io.formats.style.Styler at 0x21e093db520>"
|
952 |
+
]
|
953 |
+
},
|
954 |
+
"metadata": {},
|
955 |
+
"output_type": "display_data"
|
956 |
+
},
|
957 |
+
{
|
958 |
+
"data": {
|
959 |
+
"text/html": [
|
960 |
+
"<style type=\"text/css\">\n",
|
961 |
+
"</style>\n",
|
962 |
+
"<table id=\"T_a0470\">\n",
|
963 |
+
" <caption>RF Validation Scores</caption>\n",
|
964 |
+
" <thead>\n",
|
965 |
+
" <tr>\n",
|
966 |
+
" <th class=\"blank level0\" > </th>\n",
|
967 |
+
" <th id=\"T_a0470_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
968 |
+
" <th id=\"T_a0470_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
969 |
+
" <th id=\"T_a0470_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
970 |
+
" <th id=\"T_a0470_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
971 |
+
" <th id=\"T_a0470_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
972 |
+
" <th id=\"T_a0470_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
973 |
+
" </tr>\n",
|
974 |
+
" </thead>\n",
|
975 |
+
" <tbody>\n",
|
976 |
+
" <tr>\n",
|
977 |
+
" <th id=\"T_a0470_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
978 |
+
" <td id=\"T_a0470_row0_col0\" class=\"data row0 col0\" >1.522204</td>\n",
|
979 |
+
" <td id=\"T_a0470_row0_col1\" class=\"data row0 col1\" >1.489951</td>\n",
|
980 |
+
" <td id=\"T_a0470_row0_col2\" class=\"data row0 col2\" >1.485231</td>\n",
|
981 |
+
" <td id=\"T_a0470_row0_col3\" class=\"data row0 col3\" >1.517533</td>\n",
|
982 |
+
" <td id=\"T_a0470_row0_col4\" class=\"data row0 col4\" >1.491071</td>\n",
|
983 |
+
" <td id=\"T_a0470_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
984 |
+
" </tr>\n",
|
985 |
+
" <tr>\n",
|
986 |
+
" <th id=\"T_a0470_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
987 |
+
" <td id=\"T_a0470_row1_col0\" class=\"data row1 col0\" >0.152138</td>\n",
|
988 |
+
" <td id=\"T_a0470_row1_col1\" class=\"data row1 col1\" >0.145954</td>\n",
|
989 |
+
" <td id=\"T_a0470_row1_col2\" class=\"data row1 col2\" >0.156016</td>\n",
|
990 |
+
" <td id=\"T_a0470_row1_col3\" class=\"data row1 col3\" >0.150085</td>\n",
|
991 |
+
" <td id=\"T_a0470_row1_col4\" class=\"data row1 col4\" >0.141980</td>\n",
|
992 |
+
" <td id=\"T_a0470_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
993 |
+
" </tr>\n",
|
994 |
+
" <tr>\n",
|
995 |
+
" <th id=\"T_a0470_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
996 |
+
" <td id=\"T_a0470_row2_col0\" class=\"data row2 col0\" >0.738655</td>\n",
|
997 |
+
" <td id=\"T_a0470_row2_col1\" class=\"data row2 col1\" >0.746779</td>\n",
|
998 |
+
" <td id=\"T_a0470_row2_col2\" class=\"data row2 col2\" >0.740544</td>\n",
|
999 |
+
" <td id=\"T_a0470_row2_col3\" class=\"data row2 col3\" >0.734099</td>\n",
|
1000 |
+
" <td id=\"T_a0470_row2_col4\" class=\"data row2 col4\" >0.737461</td>\n",
|
1001 |
+
" <td id=\"T_a0470_row2_col5\" class=\"data row2 col5\" >0.733595</td>\n",
|
1002 |
+
" </tr>\n",
|
1003 |
+
" <tr>\n",
|
1004 |
+
" <th id=\"T_a0470_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1005 |
+
" <td id=\"T_a0470_row3_col0\" class=\"data row3 col0\" >0.240846</td>\n",
|
1006 |
+
" <td id=\"T_a0470_row3_col1\" class=\"data row3 col1\" >0.262643</td>\n",
|
1007 |
+
" <td id=\"T_a0470_row3_col2\" class=\"data row3 col2\" >0.227045</td>\n",
|
1008 |
+
" <td id=\"T_a0470_row3_col3\" class=\"data row3 col3\" >0.203191</td>\n",
|
1009 |
+
" <td id=\"T_a0470_row3_col4\" class=\"data row3 col4\" >0.244964</td>\n",
|
1010 |
+
" <td id=\"T_a0470_row3_col5\" class=\"data row3 col5\" >0.242379</td>\n",
|
1011 |
+
" </tr>\n",
|
1012 |
+
" <tr>\n",
|
1013 |
+
" <th id=\"T_a0470_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1014 |
+
" <td id=\"T_a0470_row4_col0\" class=\"data row4 col0\" >0.328160</td>\n",
|
1015 |
+
" <td id=\"T_a0470_row4_col1\" class=\"data row4 col1\" >0.359375</td>\n",
|
1016 |
+
" <td id=\"T_a0470_row4_col2\" class=\"data row4 col2\" >0.323040</td>\n",
|
1017 |
+
" <td id=\"T_a0470_row4_col3\" class=\"data row4 col3\" >0.292271</td>\n",
|
1018 |
+
" <td id=\"T_a0470_row4_col4\" class=\"data row4 col4\" >0.328294</td>\n",
|
1019 |
+
" <td id=\"T_a0470_row4_col5\" class=\"data row4 col5\" >0.318359</td>\n",
|
1020 |
+
" </tr>\n",
|
1021 |
+
" <tr>\n",
|
1022 |
+
" <th id=\"T_a0470_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1023 |
+
" <td id=\"T_a0470_row5_col0\" class=\"data row5 col0\" >0.190231</td>\n",
|
1024 |
+
" <td id=\"T_a0470_row5_col1\" class=\"data row5 col1\" >0.206941</td>\n",
|
1025 |
+
" <td id=\"T_a0470_row5_col2\" class=\"data row5 col2\" >0.175032</td>\n",
|
1026 |
+
" <td id=\"T_a0470_row5_col3\" class=\"data row5 col3\" >0.155727</td>\n",
|
1027 |
+
" <td id=\"T_a0470_row5_col4\" class=\"data row5 col4\" >0.195373</td>\n",
|
1028 |
+
" <td id=\"T_a0470_row5_col5\" class=\"data row5 col5\" >0.195678</td>\n",
|
1029 |
+
" </tr>\n",
|
1030 |
+
" <tr>\n",
|
1031 |
+
" <th id=\"T_a0470_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1032 |
+
" <td id=\"T_a0470_row6_col0\" class=\"data row6 col0\" >0.616397</td>\n",
|
1033 |
+
" <td id=\"T_a0470_row6_col1\" class=\"data row6 col1\" >0.614756</td>\n",
|
1034 |
+
" <td id=\"T_a0470_row6_col2\" class=\"data row6 col2\" >0.597688</td>\n",
|
1035 |
+
" <td id=\"T_a0470_row6_col3\" class=\"data row6 col3\" >0.602290</td>\n",
|
1036 |
+
" <td id=\"T_a0470_row6_col4\" class=\"data row6 col4\" >0.596575</td>\n",
|
1037 |
+
" <td id=\"T_a0470_row6_col5\" class=\"data row6 col5\" >0.601436</td>\n",
|
1038 |
+
" </tr>\n",
|
1039 |
+
" <tr>\n",
|
1040 |
+
" <th id=\"T_a0470_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1041 |
+
" <td id=\"T_a0470_row7_col0\" class=\"data row7 col0\" >-0.184849</td>\n",
|
1042 |
+
" <td id=\"T_a0470_row7_col1\" class=\"data row7 col1\" >-0.182776</td>\n",
|
1043 |
+
" <td id=\"T_a0470_row7_col2\" class=\"data row7 col2\" >-0.186704</td>\n",
|
1044 |
+
" <td id=\"T_a0470_row7_col3\" class=\"data row7 col3\" >-0.187369</td>\n",
|
1045 |
+
" <td id=\"T_a0470_row7_col4\" class=\"data row7 col4\" >-0.189693</td>\n",
|
1046 |
+
" <td id=\"T_a0470_row7_col5\" class=\"data row7 col5\" >-0.188506</td>\n",
|
1047 |
+
" </tr>\n",
|
1048 |
+
" <tr>\n",
|
1049 |
+
" <th id=\"T_a0470_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1050 |
+
" <td id=\"T_a0470_row8_col0\" class=\"data row8 col0\" >-0.711816</td>\n",
|
1051 |
+
" <td id=\"T_a0470_row8_col1\" class=\"data row8 col1\" >-0.673727</td>\n",
|
1052 |
+
" <td id=\"T_a0470_row8_col2\" class=\"data row8 col2\" >-0.766073</td>\n",
|
1053 |
+
" <td id=\"T_a0470_row8_col3\" class=\"data row8 col3\" >-0.719537</td>\n",
|
1054 |
+
" <td id=\"T_a0470_row8_col4\" class=\"data row8 col4\" >-0.775028</td>\n",
|
1055 |
+
" <td id=\"T_a0470_row8_col5\" class=\"data row8 col5\" >-0.743245</td>\n",
|
1056 |
+
" </tr>\n",
|
1057 |
+
" </tbody>\n",
|
1058 |
+
"</table>\n"
|
1059 |
+
],
|
1060 |
+
"text/plain": [
|
1061 |
+
"<pandas.io.formats.style.Styler at 0x21e6237c3a0>"
|
1062 |
+
]
|
1063 |
+
},
|
1064 |
+
"metadata": {},
|
1065 |
+
"output_type": "display_data"
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"data": {
|
1069 |
+
"text/html": [
|
1070 |
+
"<style type=\"text/css\">\n",
|
1071 |
+
"</style>\n",
|
1072 |
+
"<table id=\"T_87742\">\n",
|
1073 |
+
" <caption>GBC Validation Scores</caption>\n",
|
1074 |
+
" <thead>\n",
|
1075 |
+
" <tr>\n",
|
1076 |
+
" <th class=\"blank level0\" > </th>\n",
|
1077 |
+
" <th id=\"T_87742_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1078 |
+
" <th id=\"T_87742_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1079 |
+
" <th id=\"T_87742_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1080 |
+
" <th id=\"T_87742_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1081 |
+
" <th id=\"T_87742_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1082 |
+
" <th id=\"T_87742_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1083 |
+
" </tr>\n",
|
1084 |
+
" </thead>\n",
|
1085 |
+
" <tbody>\n",
|
1086 |
+
" <tr>\n",
|
1087 |
+
" <th id=\"T_87742_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1088 |
+
" <td id=\"T_87742_row0_col0\" class=\"data row0 col0\" >1.292703</td>\n",
|
1089 |
+
" <td id=\"T_87742_row0_col1\" class=\"data row0 col1\" >1.329037</td>\n",
|
1090 |
+
" <td id=\"T_87742_row0_col2\" class=\"data row0 col2\" >1.315473</td>\n",
|
1091 |
+
" <td id=\"T_87742_row0_col3\" class=\"data row0 col3\" >1.299452</td>\n",
|
1092 |
+
" <td id=\"T_87742_row0_col4\" class=\"data row0 col4\" >1.313950</td>\n",
|
1093 |
+
" <td id=\"T_87742_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1094 |
+
" </tr>\n",
|
1095 |
+
" <tr>\n",
|
1096 |
+
" <th id=\"T_87742_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1097 |
+
" <td id=\"T_87742_row1_col0\" class=\"data row1 col0\" >0.014560</td>\n",
|
1098 |
+
" <td id=\"T_87742_row1_col1\" class=\"data row1 col1\" >0.026127</td>\n",
|
1099 |
+
" <td id=\"T_87742_row1_col2\" class=\"data row1 col2\" >0.018812</td>\n",
|
1100 |
+
" <td id=\"T_87742_row1_col3\" class=\"data row1 col3\" >0.015999</td>\n",
|
1101 |
+
" <td id=\"T_87742_row1_col4\" class=\"data row1 col4\" >0.020753</td>\n",
|
1102 |
+
" <td id=\"T_87742_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1103 |
+
" </tr>\n",
|
1104 |
+
" <tr>\n",
|
1105 |
+
" <th id=\"T_87742_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1106 |
+
" <td id=\"T_87742_row2_col0\" class=\"data row2 col0\" >0.783193</td>\n",
|
1107 |
+
" <td id=\"T_87742_row2_col1\" class=\"data row2 col1\" >0.777871</td>\n",
|
1108 |
+
" <td id=\"T_87742_row2_col2\" class=\"data row2 col2\" >0.782292</td>\n",
|
1109 |
+
" <td id=\"T_87742_row2_col3\" class=\"data row2 col3\" >0.779770</td>\n",
|
1110 |
+
" <td id=\"T_87742_row2_col4\" class=\"data row2 col4\" >0.780050</td>\n",
|
1111 |
+
" <td id=\"T_87742_row2_col5\" class=\"data row2 col5\" >0.781830</td>\n",
|
1112 |
+
" </tr>\n",
|
1113 |
+
" <tr>\n",
|
1114 |
+
" <th id=\"T_87742_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1115 |
+
" <td id=\"T_87742_row3_col0\" class=\"data row3 col0\" >0.112385</td>\n",
|
1116 |
+
" <td id=\"T_87742_row3_col1\" class=\"data row3 col1\" >0.072515</td>\n",
|
1117 |
+
" <td id=\"T_87742_row3_col2\" class=\"data row3 col2\" >0.091228</td>\n",
|
1118 |
+
" <td id=\"T_87742_row3_col3\" class=\"data row3 col3\" >0.075294</td>\n",
|
1119 |
+
" <td id=\"T_87742_row3_col4\" class=\"data row3 col4\" >0.081871</td>\n",
|
1120 |
+
" <td id=\"T_87742_row3_col5\" class=\"data row3 col5\" >0.085479</td>\n",
|
1121 |
+
" </tr>\n",
|
1122 |
+
" <tr>\n",
|
1123 |
+
" <th id=\"T_87742_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1124 |
+
" <td id=\"T_87742_row4_col0\" class=\"data row4 col0\" >0.521277</td>\n",
|
1125 |
+
" <td id=\"T_87742_row4_col1\" class=\"data row4 col1\" >0.402597</td>\n",
|
1126 |
+
" <td id=\"T_87742_row4_col2\" class=\"data row4 col2\" >0.500000</td>\n",
|
1127 |
+
" <td id=\"T_87742_row4_col3\" class=\"data row4 col3\" >0.438356</td>\n",
|
1128 |
+
" <td id=\"T_87742_row4_col4\" class=\"data row4 col4\" >0.454545</td>\n",
|
1129 |
+
" <td id=\"T_87742_row4_col5\" class=\"data row4 col5\" >0.490566</td>\n",
|
1130 |
+
" </tr>\n",
|
1131 |
+
" <tr>\n",
|
1132 |
+
" <th id=\"T_87742_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1133 |
+
" <td id=\"T_87742_row5_col0\" class=\"data row5 col0\" >0.062982</td>\n",
|
1134 |
+
" <td id=\"T_87742_row5_col1\" class=\"data row5 col1\" >0.039846</td>\n",
|
1135 |
+
" <td id=\"T_87742_row5_col2\" class=\"data row5 col2\" >0.050193</td>\n",
|
1136 |
+
" <td id=\"T_87742_row5_col3\" class=\"data row5 col3\" >0.041184</td>\n",
|
1137 |
+
" <td id=\"T_87742_row5_col4\" class=\"data row5 col4\" >0.044987</td>\n",
|
1138 |
+
" <td id=\"T_87742_row5_col5\" class=\"data row5 col5\" >0.046819</td>\n",
|
1139 |
+
" </tr>\n",
|
1140 |
+
" <tr>\n",
|
1141 |
+
" <th id=\"T_87742_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1142 |
+
" <td id=\"T_87742_row6_col0\" class=\"data row6 col0\" >0.684278</td>\n",
|
1143 |
+
" <td id=\"T_87742_row6_col1\" class=\"data row6 col1\" >0.665449</td>\n",
|
1144 |
+
" <td id=\"T_87742_row6_col2\" class=\"data row6 col2\" >0.670804</td>\n",
|
1145 |
+
" <td id=\"T_87742_row6_col3\" class=\"data row6 col3\" >0.671522</td>\n",
|
1146 |
+
" <td id=\"T_87742_row6_col4\" class=\"data row6 col4\" >0.654303</td>\n",
|
1147 |
+
" <td id=\"T_87742_row6_col5\" class=\"data row6 col5\" >0.669660</td>\n",
|
1148 |
+
" </tr>\n",
|
1149 |
+
" <tr>\n",
|
1150 |
+
" <th id=\"T_87742_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1151 |
+
" <td id=\"T_87742_row7_col0\" class=\"data row7 col0\" >-0.157619</td>\n",
|
1152 |
+
" <td id=\"T_87742_row7_col1\" class=\"data row7 col1\" >-0.160471</td>\n",
|
1153 |
+
" <td id=\"T_87742_row7_col2\" class=\"data row7 col2\" >-0.159413</td>\n",
|
1154 |
+
" <td id=\"T_87742_row7_col3\" class=\"data row7 col3\" >-0.159018</td>\n",
|
1155 |
+
" <td id=\"T_87742_row7_col4\" class=\"data row7 col4\" >-0.161594</td>\n",
|
1156 |
+
" <td id=\"T_87742_row7_col5\" class=\"data row7 col5\" >-0.159691</td>\n",
|
1157 |
+
" </tr>\n",
|
1158 |
+
" <tr>\n",
|
1159 |
+
" <th id=\"T_87742_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1160 |
+
" <td id=\"T_87742_row8_col0\" class=\"data row8 col0\" >-0.488690</td>\n",
|
1161 |
+
" <td id=\"T_87742_row8_col1\" class=\"data row8 col1\" >-0.495751</td>\n",
|
1162 |
+
" <td id=\"T_87742_row8_col2\" class=\"data row8 col2\" >-0.493124</td>\n",
|
1163 |
+
" <td id=\"T_87742_row8_col3\" class=\"data row8 col3\" >-0.492051</td>\n",
|
1164 |
+
" <td id=\"T_87742_row8_col4\" class=\"data row8 col4\" >-0.498777</td>\n",
|
1165 |
+
" <td id=\"T_87742_row8_col5\" class=\"data row8 col5\" >-0.493648</td>\n",
|
1166 |
+
" </tr>\n",
|
1167 |
+
" </tbody>\n",
|
1168 |
+
"</table>\n"
|
1169 |
+
],
|
1170 |
+
"text/plain": [
|
1171 |
+
"<pandas.io.formats.style.Styler at 0x21e0ab71420>"
|
1172 |
+
]
|
1173 |
+
},
|
1174 |
+
"metadata": {},
|
1175 |
+
"output_type": "display_data"
|
1176 |
+
},
|
1177 |
+
{
|
1178 |
+
"data": {
|
1179 |
+
"text/html": [
|
1180 |
+
"<style type=\"text/css\">\n",
|
1181 |
+
"</style>\n",
|
1182 |
+
"<table id=\"T_0e44f\">\n",
|
1183 |
+
" <caption>XGB Validation Scores</caption>\n",
|
1184 |
+
" <thead>\n",
|
1185 |
+
" <tr>\n",
|
1186 |
+
" <th class=\"blank level0\" > </th>\n",
|
1187 |
+
" <th id=\"T_0e44f_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1188 |
+
" <th id=\"T_0e44f_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1189 |
+
" <th id=\"T_0e44f_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1190 |
+
" <th id=\"T_0e44f_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1191 |
+
" <th id=\"T_0e44f_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1192 |
+
" <th id=\"T_0e44f_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1193 |
+
" </tr>\n",
|
1194 |
+
" </thead>\n",
|
1195 |
+
" <tbody>\n",
|
1196 |
+
" <tr>\n",
|
1197 |
+
" <th id=\"T_0e44f_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1198 |
+
" <td id=\"T_0e44f_row0_col0\" class=\"data row0 col0\" >0.643858</td>\n",
|
1199 |
+
" <td id=\"T_0e44f_row0_col1\" class=\"data row0 col1\" >0.640153</td>\n",
|
1200 |
+
" <td id=\"T_0e44f_row0_col2\" class=\"data row0 col2\" >0.677121</td>\n",
|
1201 |
+
" <td id=\"T_0e44f_row0_col3\" class=\"data row0 col3\" >0.634137</td>\n",
|
1202 |
+
" <td id=\"T_0e44f_row0_col4\" class=\"data row0 col4\" >0.669338</td>\n",
|
1203 |
+
" <td id=\"T_0e44f_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1204 |
+
" </tr>\n",
|
1205 |
+
" <tr>\n",
|
1206 |
+
" <th id=\"T_0e44f_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1207 |
+
" <td id=\"T_0e44f_row1_col0\" class=\"data row1 col0\" >0.023605</td>\n",
|
1208 |
+
" <td id=\"T_0e44f_row1_col1\" class=\"data row1 col1\" >0.016412</td>\n",
|
1209 |
+
" <td id=\"T_0e44f_row1_col2\" class=\"data row1 col2\" >0.020805</td>\n",
|
1210 |
+
" <td id=\"T_0e44f_row1_col3\" class=\"data row1 col3\" >0.015040</td>\n",
|
1211 |
+
" <td id=\"T_0e44f_row1_col4\" class=\"data row1 col4\" >0.028892</td>\n",
|
1212 |
+
" <td id=\"T_0e44f_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1213 |
+
" </tr>\n",
|
1214 |
+
" <tr>\n",
|
1215 |
+
" <th id=\"T_0e44f_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1216 |
+
" <td id=\"T_0e44f_row2_col0\" class=\"data row2 col0\" >0.620168</td>\n",
|
1217 |
+
" <td id=\"T_0e44f_row2_col1\" class=\"data row2 col1\" >0.612605</td>\n",
|
1218 |
+
" <td id=\"T_0e44f_row2_col2\" class=\"data row2 col2\" >0.604371</td>\n",
|
1219 |
+
" <td id=\"T_0e44f_row2_col3\" class=\"data row2 col3\" >0.628748</td>\n",
|
1220 |
+
" <td id=\"T_0e44f_row2_col4\" class=\"data row2 col4\" >0.620062</td>\n",
|
1221 |
+
" <td id=\"T_0e44f_row2_col5\" class=\"data row2 col5\" >0.619216</td>\n",
|
1222 |
+
" </tr>\n",
|
1223 |
+
" <tr>\n",
|
1224 |
+
" <th id=\"T_0e44f_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1225 |
+
" <td id=\"T_0e44f_row3_col0\" class=\"data row3 col0\" >0.397869</td>\n",
|
1226 |
+
" <td id=\"T_0e44f_row3_col1\" class=\"data row3 col1\" >0.386696</td>\n",
|
1227 |
+
" <td id=\"T_0e44f_row3_col2\" class=\"data row3 col2\" >0.377974</td>\n",
|
1228 |
+
" <td id=\"T_0e44f_row3_col3\" class=\"data row3 col3\" >0.386290</td>\n",
|
1229 |
+
" <td id=\"T_0e44f_row3_col4\" class=\"data row3 col4\" >0.383076</td>\n",
|
1230 |
+
" <td id=\"T_0e44f_row3_col5\" class=\"data row3 col5\" >0.387639</td>\n",
|
1231 |
+
" </tr>\n",
|
1232 |
+
" <tr>\n",
|
1233 |
+
" <th id=\"T_0e44f_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1234 |
+
" <td id=\"T_0e44f_row4_col0\" class=\"data row4 col0\" >0.303935</td>\n",
|
1235 |
+
" <td id=\"T_0e44f_row4_col1\" class=\"data row4 col1\" >0.295193</td>\n",
|
1236 |
+
" <td id=\"T_0e44f_row4_col2\" class=\"data row4 col2\" >0.287341</td>\n",
|
1237 |
+
" <td id=\"T_0e44f_row4_col3\" class=\"data row4 col3\" >0.301737</td>\n",
|
1238 |
+
" <td id=\"T_0e44f_row4_col4\" class=\"data row4 col4\" >0.296479</td>\n",
|
1239 |
+
" <td id=\"T_0e44f_row4_col5\" class=\"data row4 col5\" >0.298285</td>\n",
|
1240 |
+
" </tr>\n",
|
1241 |
+
" <tr>\n",
|
1242 |
+
" <th id=\"T_0e44f_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1243 |
+
" <td id=\"T_0e44f_row5_col0\" class=\"data row5 col0\" >0.575835</td>\n",
|
1244 |
+
" <td id=\"T_0e44f_row5_col1\" class=\"data row5 col1\" >0.560411</td>\n",
|
1245 |
+
" <td id=\"T_0e44f_row5_col2\" class=\"data row5 col2\" >0.552124</td>\n",
|
1246 |
+
" <td id=\"T_0e44f_row5_col3\" class=\"data row5 col3\" >0.536680</td>\n",
|
1247 |
+
" <td id=\"T_0e44f_row5_col4\" class=\"data row5 col4\" >0.541131</td>\n",
|
1248 |
+
" <td id=\"T_0e44f_row5_col5\" class=\"data row5 col5\" >0.553421</td>\n",
|
1249 |
+
" </tr>\n",
|
1250 |
+
" <tr>\n",
|
1251 |
+
" <th id=\"T_0e44f_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1252 |
+
" <td id=\"T_0e44f_row6_col0\" class=\"data row6 col0\" >0.640317</td>\n",
|
1253 |
+
" <td id=\"T_0e44f_row6_col1\" class=\"data row6 col1\" >0.626210</td>\n",
|
1254 |
+
" <td id=\"T_0e44f_row6_col2\" class=\"data row6 col2\" >0.619924</td>\n",
|
1255 |
+
" <td id=\"T_0e44f_row6_col3\" class=\"data row6 col3\" >0.627078</td>\n",
|
1256 |
+
" <td id=\"T_0e44f_row6_col4\" class=\"data row6 col4\" >0.620747</td>\n",
|
1257 |
+
" <td id=\"T_0e44f_row6_col5\" class=\"data row6 col5\" >0.630028</td>\n",
|
1258 |
+
" </tr>\n",
|
1259 |
+
" <tr>\n",
|
1260 |
+
" <th id=\"T_0e44f_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1261 |
+
" <td id=\"T_0e44f_row7_col0\" class=\"data row7 col0\" >-0.231951</td>\n",
|
1262 |
+
" <td id=\"T_0e44f_row7_col1\" class=\"data row7 col1\" >-0.238560</td>\n",
|
1263 |
+
" <td id=\"T_0e44f_row7_col2\" class=\"data row7 col2\" >-0.237950</td>\n",
|
1264 |
+
" <td id=\"T_0e44f_row7_col3\" class=\"data row7 col3\" >-0.233297</td>\n",
|
1265 |
+
" <td id=\"T_0e44f_row7_col4\" class=\"data row7 col4\" >-0.239657</td>\n",
|
1266 |
+
" <td id=\"T_0e44f_row7_col5\" class=\"data row7 col5\" >-0.236789</td>\n",
|
1267 |
+
" </tr>\n",
|
1268 |
+
" <tr>\n",
|
1269 |
+
" <th id=\"T_0e44f_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1270 |
+
" <td id=\"T_0e44f_row8_col0\" class=\"data row8 col0\" >-0.660186</td>\n",
|
1271 |
+
" <td id=\"T_0e44f_row8_col1\" class=\"data row8 col1\" >-0.678450</td>\n",
|
1272 |
+
" <td id=\"T_0e44f_row8_col2\" class=\"data row8 col2\" >-0.673157</td>\n",
|
1273 |
+
" <td id=\"T_0e44f_row8_col3\" class=\"data row8 col3\" >-0.666774</td>\n",
|
1274 |
+
" <td id=\"T_0e44f_row8_col4\" class=\"data row8 col4\" >-0.681109</td>\n",
|
1275 |
+
" <td id=\"T_0e44f_row8_col5\" class=\"data row8 col5\" >-0.671979</td>\n",
|
1276 |
+
" </tr>\n",
|
1277 |
+
" </tbody>\n",
|
1278 |
+
"</table>\n"
|
1279 |
+
],
|
1280 |
+
"text/plain": [
|
1281 |
+
"<pandas.io.formats.style.Styler at 0x21e0aab11e0>"
|
1282 |
+
]
|
1283 |
+
},
|
1284 |
+
"metadata": {},
|
1285 |
+
"output_type": "display_data"
|
1286 |
+
}
|
1287 |
+
],
|
1288 |
+
"source": [
|
1289 |
+
"# XGB hyperparameter that deals with unbalanced output\n",
|
1290 |
+
"scale_pos_weight = Y.mean()**-1\n",
|
1291 |
+
"\n",
|
1292 |
+
"# Creating the model objects\n",
|
1293 |
+
"cls_lr = LogisticRegression(\n",
|
1294 |
+
" class_weight=\"balanced\", # Hyperparameter to deal with unbalanced output\n",
|
1295 |
+
" random_state=lucky_num)\n",
|
1296 |
+
"# cls_svm = SVC(random_state=lucky_num) # Removed due to its resource consumption and worse results\n",
|
1297 |
+
"cls_nb = GaussianNB()\n",
|
1298 |
+
"cls_knn = KNeighborsClassifier()\n",
|
1299 |
+
"cls_rf = RandomForestClassifier(\n",
|
1300 |
+
" random_state=lucky_num,\n",
|
1301 |
+
" class_weight=\"balanced_subsample\") # Hyperparameter to deal with unbalanced output\n",
|
1302 |
+
"cls_gbc = GradientBoostingClassifier(random_state=lucky_num)\n",
|
1303 |
+
"cls_xgb = xgb.XGBClassifier(\n",
|
1304 |
+
" objective=\"binary:logistic\",\n",
|
1305 |
+
" verbose=None,\n",
|
1306 |
+
" random_state=lucky_num,\n",
|
1307 |
+
" scale_pos_weight = scale_pos_weight)\n",
|
1308 |
+
"\n",
|
1309 |
+
"# Lists to iterate on our modeling function\n",
|
1310 |
+
"cls_name = [\"LR\", \"NB\", \"KNN\", \"RF\", \"GBC\", \"XGB\"]\n",
|
1311 |
+
"cls_list = [cls_lr, cls_nb, cls_knn, cls_rf, cls_gbc, cls_xgb]\n",
|
1312 |
+
"\n",
|
1313 |
+
"mdl_summaries = []\n",
|
1314 |
+
"for name, inst in zip(cls_name, cls_list):\n",
|
1315 |
+
" mdl_list = create_model(name, inst)\n",
|
1316 |
+
" mdl_list = [name] + mdl_list\n",
|
1317 |
+
" mdl_summaries.append(mdl_list)\n",
|
1318 |
+
"\n",
|
1319 |
+
"df_mdl = pd.DataFrame(\n",
|
1320 |
+
" mdl_summaries,\n",
|
1321 |
+
" columns=[\n",
|
1322 |
+
" \"model\",\n",
|
1323 |
+
" \"test_accuracy\",\n",
|
1324 |
+
" \"test_f1\",\n",
|
1325 |
+
" \"test_precision\",\n",
|
1326 |
+
" \"test_recall\",\n",
|
1327 |
+
" \"test_roc_auc\",\n",
|
1328 |
+
" \"test_brier\",\n",
|
1329 |
+
" \"test_log_loss\"])"
|
1330 |
+
]
|
1331 |
+
},
|
1332 |
+
{
|
1333 |
+
"cell_type": "code",
|
1334 |
+
"execution_count": 16,
|
1335 |
+
"metadata": {},
|
1336 |
+
"outputs": [
|
1337 |
+
{
|
1338 |
+
"data": {
|
1339 |
+
"text/html": [
|
1340 |
+
"<style type=\"text/css\">\n",
|
1341 |
+
"</style>\n",
|
1342 |
+
"<table id=\"T_0de30\">\n",
|
1343 |
+
" <caption>Test set validation scores</caption>\n",
|
1344 |
+
" <thead>\n",
|
1345 |
+
" <tr>\n",
|
1346 |
+
" <th class=\"blank level0\" > </th>\n",
|
1347 |
+
" <th id=\"T_0de30_level0_col0\" class=\"col_heading level0 col0\" >model</th>\n",
|
1348 |
+
" <th id=\"T_0de30_level0_col1\" class=\"col_heading level0 col1\" >test_accuracy</th>\n",
|
1349 |
+
" <th id=\"T_0de30_level0_col2\" class=\"col_heading level0 col2\" >test_f1</th>\n",
|
1350 |
+
" <th id=\"T_0de30_level0_col3\" class=\"col_heading level0 col3\" >test_precision</th>\n",
|
1351 |
+
" <th id=\"T_0de30_level0_col4\" class=\"col_heading level0 col4\" >test_recall</th>\n",
|
1352 |
+
" <th id=\"T_0de30_level0_col5\" class=\"col_heading level0 col5\" >test_roc_auc</th>\n",
|
1353 |
+
" <th id=\"T_0de30_level0_col6\" class=\"col_heading level0 col6\" >test_brier</th>\n",
|
1354 |
+
" <th id=\"T_0de30_level0_col7\" class=\"col_heading level0 col7\" >test_log_loss</th>\n",
|
1355 |
+
" </tr>\n",
|
1356 |
+
" </thead>\n",
|
1357 |
+
" <tbody>\n",
|
1358 |
+
" <tr>\n",
|
1359 |
+
" <th id=\"T_0de30_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
|
1360 |
+
" <td id=\"T_0de30_row0_col0\" class=\"data row0 col0\" >LR</td>\n",
|
1361 |
+
" <td id=\"T_0de30_row0_col1\" class=\"data row0 col1\" >0.660523</td>\n",
|
1362 |
+
" <td id=\"T_0de30_row0_col2\" class=\"data row0 col2\" >0.414431</td>\n",
|
1363 |
+
" <td id=\"T_0de30_row0_col3\" class=\"data row0 col3\" >0.331889</td>\n",
|
1364 |
+
" <td id=\"T_0de30_row0_col4\" class=\"data row0 col4\" >0.551621</td>\n",
|
1365 |
+
" <td id=\"T_0de30_row0_col5\" class=\"data row0 col5\" >0.667540</td>\n",
|
1366 |
+
" <td id=\"T_0de30_row0_col6\" class=\"data row0 col6\" >0.225709</td>\n",
|
1367 |
+
" <td id=\"T_0de30_row0_col7\" class=\"data row0 col7\" >0.644816</td>\n",
|
1368 |
+
" </tr>\n",
|
1369 |
+
" <tr>\n",
|
1370 |
+
" <th id=\"T_0de30_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
|
1371 |
+
" <td id=\"T_0de30_row1_col0\" class=\"data row1 col0\" >XGB</td>\n",
|
1372 |
+
" <td id=\"T_0de30_row1_col1\" class=\"data row1 col1\" >0.619216</td>\n",
|
1373 |
+
" <td id=\"T_0de30_row1_col2\" class=\"data row1 col2\" >0.387639</td>\n",
|
1374 |
+
" <td id=\"T_0de30_row1_col3\" class=\"data row1 col3\" >0.298285</td>\n",
|
1375 |
+
" <td id=\"T_0de30_row1_col4\" class=\"data row1 col4\" >0.553421</td>\n",
|
1376 |
+
" <td id=\"T_0de30_row1_col5\" class=\"data row1 col5\" >0.630028</td>\n",
|
1377 |
+
" <td id=\"T_0de30_row1_col6\" class=\"data row1 col6\" >0.236789</td>\n",
|
1378 |
+
" <td id=\"T_0de30_row1_col7\" class=\"data row1 col7\" >0.671979</td>\n",
|
1379 |
+
" </tr>\n",
|
1380 |
+
" <tr>\n",
|
1381 |
+
" <th id=\"T_0de30_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
|
1382 |
+
" <td id=\"T_0de30_row2_col0\" class=\"data row2 col0\" >NB</td>\n",
|
1383 |
+
" <td id=\"T_0de30_row2_col1\" class=\"data row2 col1\" >0.699608</td>\n",
|
1384 |
+
" <td id=\"T_0de30_row2_col2\" class=\"data row2 col2\" >0.381260</td>\n",
|
1385 |
+
" <td id=\"T_0de30_row2_col3\" class=\"data row2 col3\" >0.345703</td>\n",
|
1386 |
+
" <td id=\"T_0de30_row2_col4\" class=\"data row2 col4\" >0.424970</td>\n",
|
1387 |
+
" <td id=\"T_0de30_row2_col5\" class=\"data row2 col5\" >0.634031</td>\n",
|
1388 |
+
" <td id=\"T_0de30_row2_col6\" class=\"data row2 col6\" >0.268468</td>\n",
|
1389 |
+
" <td id=\"T_0de30_row2_col7\" class=\"data row2 col7\" >1.659029</td>\n",
|
1390 |
+
" </tr>\n",
|
1391 |
+
" <tr>\n",
|
1392 |
+
" <th id=\"T_0de30_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
|
1393 |
+
" <td id=\"T_0de30_row3_col0\" class=\"data row3 col0\" >RF</td>\n",
|
1394 |
+
" <td id=\"T_0de30_row3_col1\" class=\"data row3 col1\" >0.733595</td>\n",
|
1395 |
+
" <td id=\"T_0de30_row3_col2\" class=\"data row3 col2\" >0.242379</td>\n",
|
1396 |
+
" <td id=\"T_0de30_row3_col3\" class=\"data row3 col3\" >0.318359</td>\n",
|
1397 |
+
" <td id=\"T_0de30_row3_col4\" class=\"data row3 col4\" >0.195678</td>\n",
|
1398 |
+
" <td id=\"T_0de30_row3_col5\" class=\"data row3 col5\" >0.601436</td>\n",
|
1399 |
+
" <td id=\"T_0de30_row3_col6\" class=\"data row3 col6\" >0.188506</td>\n",
|
1400 |
+
" <td id=\"T_0de30_row3_col7\" class=\"data row3 col7\" >0.743245</td>\n",
|
1401 |
+
" </tr>\n",
|
1402 |
+
" <tr>\n",
|
1403 |
+
" <th id=\"T_0de30_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
|
1404 |
+
" <td id=\"T_0de30_row4_col0\" class=\"data row4 col0\" >KNN</td>\n",
|
1405 |
+
" <td id=\"T_0de30_row4_col1\" class=\"data row4 col1\" >0.754379</td>\n",
|
1406 |
+
" <td id=\"T_0de30_row4_col2\" class=\"data row4 col2\" >0.214136</td>\n",
|
1407 |
+
" <td id=\"T_0de30_row4_col3\" class=\"data row4 col3\" >0.353103</td>\n",
|
1408 |
+
" <td id=\"T_0de30_row4_col4\" class=\"data row4 col4\" >0.153661</td>\n",
|
1409 |
+
" <td id=\"T_0de30_row4_col5\" class=\"data row4 col5\" >0.582555</td>\n",
|
1410 |
+
" <td id=\"T_0de30_row4_col6\" class=\"data row4 col6\" >0.189077</td>\n",
|
1411 |
+
" <td id=\"T_0de30_row4_col7\" class=\"data row4 col7\" >2.428060</td>\n",
|
1412 |
+
" </tr>\n",
|
1413 |
+
" <tr>\n",
|
1414 |
+
" <th id=\"T_0de30_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
|
1415 |
+
" <td id=\"T_0de30_row5_col0\" class=\"data row5 col0\" >GBC</td>\n",
|
1416 |
+
" <td id=\"T_0de30_row5_col1\" class=\"data row5 col1\" >0.781830</td>\n",
|
1417 |
+
" <td id=\"T_0de30_row5_col2\" class=\"data row5 col2\" >0.085479</td>\n",
|
1418 |
+
" <td id=\"T_0de30_row5_col3\" class=\"data row5 col3\" >0.490566</td>\n",
|
1419 |
+
" <td id=\"T_0de30_row5_col4\" class=\"data row5 col4\" >0.046819</td>\n",
|
1420 |
+
" <td id=\"T_0de30_row5_col5\" class=\"data row5 col5\" >0.669660</td>\n",
|
1421 |
+
" <td id=\"T_0de30_row5_col6\" class=\"data row5 col6\" >0.159691</td>\n",
|
1422 |
+
" <td id=\"T_0de30_row5_col7\" class=\"data row5 col7\" >0.493648</td>\n",
|
1423 |
+
" </tr>\n",
|
1424 |
+
" </tbody>\n",
|
1425 |
+
"</table>\n"
|
1426 |
+
],
|
1427 |
+
"text/plain": [
|
1428 |
+
"<pandas.io.formats.style.Styler at 0x21e044360e0>"
|
1429 |
+
]
|
1430 |
+
},
|
1431 |
+
"metadata": {},
|
1432 |
+
"output_type": "display_data"
|
1433 |
+
}
|
1434 |
+
],
|
1435 |
+
"source": [
|
1436 |
+
"df_mdl.sort_values(\n",
|
1437 |
+
" \"test_f1\",\n",
|
1438 |
+
" ascending=False,\n",
|
1439 |
+
" inplace=True,\n",
|
1440 |
+
" ignore_index=True)\n",
|
1441 |
+
"\n",
|
1442 |
+
"display(df_mdl.style.set_caption(\"Test set validation scores\"))"
|
1443 |
+
]
|
1444 |
+
},
|
1445 |
+
{
|
1446 |
+
"cell_type": "markdown",
|
1447 |
+
"metadata": {},
|
1448 |
+
"source": [
|
1449 |
+
"None of the models presents good results! We will try to fit a composite model with the 3 best."
|
1450 |
+
]
|
1451 |
+
},
|
1452 |
+
{
|
1453 |
+
"cell_type": "code",
|
1454 |
+
"execution_count": 21,
|
1455 |
+
"metadata": {},
|
1456 |
+
"outputs": [
|
1457 |
+
{
|
1458 |
+
"data": {
|
1459 |
+
"text/html": [
|
1460 |
+
"<style type=\"text/css\">\n",
|
1461 |
+
"</style>\n",
|
1462 |
+
"<table id=\"T_ca02f\">\n",
|
1463 |
+
" <caption>Test set validation scores for Composite Model</caption>\n",
|
1464 |
+
" <thead>\n",
|
1465 |
+
" <tr>\n",
|
1466 |
+
" <th class=\"blank level0\" > </th>\n",
|
1467 |
+
" <th id=\"T_ca02f_level0_col0\" class=\"col_heading level0 col0\" >CV1</th>\n",
|
1468 |
+
" <th id=\"T_ca02f_level0_col1\" class=\"col_heading level0 col1\" >CV2</th>\n",
|
1469 |
+
" <th id=\"T_ca02f_level0_col2\" class=\"col_heading level0 col2\" >CV3</th>\n",
|
1470 |
+
" <th id=\"T_ca02f_level0_col3\" class=\"col_heading level0 col3\" >CV4</th>\n",
|
1471 |
+
" <th id=\"T_ca02f_level0_col4\" class=\"col_heading level0 col4\" >CV5</th>\n",
|
1472 |
+
" <th id=\"T_ca02f_level0_col5\" class=\"col_heading level0 col5\" >TestSet</th>\n",
|
1473 |
+
" </tr>\n",
|
1474 |
+
" </thead>\n",
|
1475 |
+
" <tbody>\n",
|
1476 |
+
" <tr>\n",
|
1477 |
+
" <th id=\"T_ca02f_level0_row0\" class=\"row_heading level0 row0\" >fit_time</th>\n",
|
1478 |
+
" <td id=\"T_ca02f_row0_col0\" class=\"data row0 col0\" >0.622165</td>\n",
|
1479 |
+
" <td id=\"T_ca02f_row0_col1\" class=\"data row0 col1\" >0.732543</td>\n",
|
1480 |
+
" <td id=\"T_ca02f_row0_col2\" class=\"data row0 col2\" >0.591849</td>\n",
|
1481 |
+
" <td id=\"T_ca02f_row0_col3\" class=\"data row0 col3\" >0.699149</td>\n",
|
1482 |
+
" <td id=\"T_ca02f_row0_col4\" class=\"data row0 col4\" >0.617794</td>\n",
|
1483 |
+
" <td id=\"T_ca02f_row0_col5\" class=\"data row0 col5\" >nan</td>\n",
|
1484 |
+
" </tr>\n",
|
1485 |
+
" <tr>\n",
|
1486 |
+
" <th id=\"T_ca02f_level0_row1\" class=\"row_heading level0 row1\" >score_time</th>\n",
|
1487 |
+
" <td id=\"T_ca02f_row1_col0\" class=\"data row1 col0\" >0.023785</td>\n",
|
1488 |
+
" <td id=\"T_ca02f_row1_col1\" class=\"data row1 col1\" >0.027777</td>\n",
|
1489 |
+
" <td id=\"T_ca02f_row1_col2\" class=\"data row1 col2\" >0.030991</td>\n",
|
1490 |
+
" <td id=\"T_ca02f_row1_col3\" class=\"data row1 col3\" >0.027807</td>\n",
|
1491 |
+
" <td id=\"T_ca02f_row1_col4\" class=\"data row1 col4\" >0.023930</td>\n",
|
1492 |
+
" <td id=\"T_ca02f_row1_col5\" class=\"data row1 col5\" >nan</td>\n",
|
1493 |
+
" </tr>\n",
|
1494 |
+
" <tr>\n",
|
1495 |
+
" <th id=\"T_ca02f_level0_row2\" class=\"row_heading level0 row2\" >test_accuracy</th>\n",
|
1496 |
+
" <td id=\"T_ca02f_row2_col0\" class=\"data row2 col0\" >0.714846</td>\n",
|
1497 |
+
" <td id=\"T_ca02f_row2_col1\" class=\"data row2 col1\" >0.701120</td>\n",
|
1498 |
+
" <td id=\"T_ca02f_row2_col2\" class=\"data row2 col2\" >0.695713</td>\n",
|
1499 |
+
" <td id=\"T_ca02f_row2_col3\" class=\"data row2 col3\" >0.693752</td>\n",
|
1500 |
+
" <td id=\"T_ca02f_row2_col4\" class=\"data row2 col4\" >0.677221</td>\n",
|
1501 |
+
" <td id=\"T_ca02f_row2_col5\" class=\"data row2 col5\" >0.699346</td>\n",
|
1502 |
+
" </tr>\n",
|
1503 |
+
" <tr>\n",
|
1504 |
+
" <th id=\"T_ca02f_level0_row3\" class=\"row_heading level0 row3\" >test_f1</th>\n",
|
1505 |
+
" <td id=\"T_ca02f_row3_col0\" class=\"data row3 col0\" >0.412240</td>\n",
|
1506 |
+
" <td id=\"T_ca02f_row3_col1\" class=\"data row3 col1\" >0.404243</td>\n",
|
1507 |
+
" <td id=\"T_ca02f_row3_col2\" class=\"data row3 col2\" >0.389201</td>\n",
|
1508 |
+
" <td id=\"T_ca02f_row3_col3\" class=\"data row3 col3\" >0.385610</td>\n",
|
1509 |
+
" <td id=\"T_ca02f_row3_col4\" class=\"data row3 col4\" >0.369803</td>\n",
|
1510 |
+
" <td id=\"T_ca02f_row3_col5\" class=\"data row3 col5\" >0.389597</td>\n",
|
1511 |
+
" </tr>\n",
|
1512 |
+
" <tr>\n",
|
1513 |
+
" <th id=\"T_ca02f_level0_row4\" class=\"row_heading level0 row4\" >test_precision</th>\n",
|
1514 |
+
" <td id=\"T_ca02f_row4_col0\" class=\"data row4 col0\" >0.374214</td>\n",
|
1515 |
+
" <td id=\"T_ca02f_row4_col1\" class=\"data row4 col1\" >0.357354</td>\n",
|
1516 |
+
" <td id=\"T_ca02f_row4_col2\" class=\"data row4 col2\" >0.345654</td>\n",
|
1517 |
+
" <td id=\"T_ca02f_row4_col3\" class=\"data row4 col3\" >0.342315</td>\n",
|
1518 |
+
" <td id=\"T_ca02f_row4_col4\" class=\"data row4 col4\" >0.321905</td>\n",
|
1519 |
+
" <td id=\"T_ca02f_row4_col5\" class=\"data row4 col5\" >0.349191</td>\n",
|
1520 |
+
" </tr>\n",
|
1521 |
+
" <tr>\n",
|
1522 |
+
" <th id=\"T_ca02f_level0_row5\" class=\"row_heading level0 row5\" >test_recall</th>\n",
|
1523 |
+
" <td id=\"T_ca02f_row5_col0\" class=\"data row5 col0\" >0.458869</td>\n",
|
1524 |
+
" <td id=\"T_ca02f_row5_col1\" class=\"data row5 col1\" >0.465296</td>\n",
|
1525 |
+
" <td id=\"T_ca02f_row5_col2\" class=\"data row5 col2\" >0.445302</td>\n",
|
1526 |
+
" <td id=\"T_ca02f_row5_col3\" class=\"data row5 col3\" >0.441441</td>\n",
|
1527 |
+
" <td id=\"T_ca02f_row5_col4\" class=\"data row5 col4\" >0.434447</td>\n",
|
1528 |
+
" <td id=\"T_ca02f_row5_col5\" class=\"data row5 col5\" >0.440576</td>\n",
|
1529 |
+
" </tr>\n",
|
1530 |
+
" <tr>\n",
|
1531 |
+
" <th id=\"T_ca02f_level0_row6\" class=\"row_heading level0 row6\" >test_roc_auc</th>\n",
|
1532 |
+
" <td id=\"T_ca02f_row6_col0\" class=\"data row6 col0\" >0.679117</td>\n",
|
1533 |
+
" <td id=\"T_ca02f_row6_col1\" class=\"data row6 col1\" >0.662225</td>\n",
|
1534 |
+
" <td id=\"T_ca02f_row6_col2\" class=\"data row6 col2\" >0.651689</td>\n",
|
1535 |
+
" <td id=\"T_ca02f_row6_col3\" class=\"data row6 col3\" >0.658722</td>\n",
|
1536 |
+
" <td id=\"T_ca02f_row6_col4\" class=\"data row6 col4\" >0.640881</td>\n",
|
1537 |
+
" <td id=\"T_ca02f_row6_col5\" class=\"data row6 col5\" >0.658847</td>\n",
|
1538 |
+
" </tr>\n",
|
1539 |
+
" <tr>\n",
|
1540 |
+
" <th id=\"T_ca02f_level0_row7\" class=\"row_heading level0 row7\" >test_neg_brier_score</th>\n",
|
1541 |
+
" <td id=\"T_ca02f_row7_col0\" class=\"data row7 col0\" >-0.199904</td>\n",
|
1542 |
+
" <td id=\"T_ca02f_row7_col1\" class=\"data row7 col1\" >-0.208466</td>\n",
|
1543 |
+
" <td id=\"T_ca02f_row7_col2\" class=\"data row7 col2\" >-0.211428</td>\n",
|
1544 |
+
" <td id=\"T_ca02f_row7_col3\" class=\"data row7 col3\" >-0.209929</td>\n",
|
1545 |
+
" <td id=\"T_ca02f_row7_col4\" class=\"data row7 col4\" >-0.218236</td>\n",
|
1546 |
+
" <td id=\"T_ca02f_row7_col5\" class=\"data row7 col5\" >-0.208991</td>\n",
|
1547 |
+
" </tr>\n",
|
1548 |
+
" <tr>\n",
|
1549 |
+
" <th id=\"T_ca02f_level0_row8\" class=\"row_heading level0 row8\" >test_neg_log_loss</th>\n",
|
1550 |
+
" <td id=\"T_ca02f_row8_col0\" class=\"data row8 col0\" >-0.590700</td>\n",
|
1551 |
+
" <td id=\"T_ca02f_row8_col1\" class=\"data row8 col1\" >-0.611156</td>\n",
|
1552 |
+
" <td id=\"T_ca02f_row8_col2\" class=\"data row8 col2\" >-0.616876</td>\n",
|
1553 |
+
" <td id=\"T_ca02f_row8_col3\" class=\"data row8 col3\" >-0.613285</td>\n",
|
1554 |
+
" <td id=\"T_ca02f_row8_col4\" class=\"data row8 col4\" >-0.633314</td>\n",
|
1555 |
+
" <td id=\"T_ca02f_row8_col5\" class=\"data row8 col5\" >-0.611563</td>\n",
|
1556 |
+
" </tr>\n",
|
1557 |
+
" </tbody>\n",
|
1558 |
+
"</table>\n"
|
1559 |
+
],
|
1560 |
+
"text/plain": [
|
1561 |
+
"<pandas.io.formats.style.Styler at 0x21e04333100>"
|
1562 |
+
]
|
1563 |
+
},
|
1564 |
+
"metadata": {},
|
1565 |
+
"output_type": "display_data"
|
1566 |
+
}
|
1567 |
+
],
|
1568 |
+
"source": [
|
1569 |
+
"# Selecting the models\n",
|
1570 |
+
"cls_name = [\"LR\", \"NB\", \"XGB\"]\n",
|
1571 |
+
"cls_list = [cls_lr, cls_nb, cls_xgb]\n",
|
1572 |
+
"\n",
|
1573 |
+
"# Training the voting classifier\n",
|
1574 |
+
"cls_vot = VotingClassifier([*zip(cls_name, cls_list)], voting=\"soft\")\n",
|
1575 |
+
"cls_vot.fit(X_train, y_train)\n",
|
1576 |
+
"\n",
|
1577 |
+
"# Using cross-validation to evaluate the model fitted\n",
|
1578 |
+
"cls_cross = cross_validate(\n",
|
1579 |
+
" estimator=cls_vot,\n",
|
1580 |
+
" X=X_train,\n",
|
1581 |
+
" y=y_train,\n",
|
1582 |
+
" cv=5,\n",
|
1583 |
+
" scoring=scores)\n",
|
1584 |
+
"\n",
|
1585 |
+
"df_vot = pd.DataFrame.from_dict(cls_cross, orient='index', columns=[\"CV\"+str(i) for i in range(1,6)])\n",
|
1586 |
+
"\n",
|
1587 |
+
"# Calculating score to test set\n",
|
1588 |
+
"accurancy, f1, precision, recall, roc_auc, brier_score, log_loss_value = eval_model(cls_vot)\n",
|
1589 |
+
"\n",
|
1590 |
+
"# Filling a dataframe to better presentation\n",
|
1591 |
+
"df_vot.at[\"test_accuracy\", \"TestSet\"] = accurancy\n",
|
1592 |
+
"df_vot.at[\"test_f1\", \"TestSet\"] = f1\n",
|
1593 |
+
"df_vot.at[\"test_recall\", \"TestSet\"] = recall\n",
|
1594 |
+
"df_vot.at[\"test_precision\", \"TestSet\"] = precision\n",
|
1595 |
+
"df_vot.at[\"test_roc_auc\", \"TestSet\"] = roc_auc\n",
|
1596 |
+
"df_vot.at[\"test_neg_brier_score\", \"TestSet\"] = -brier_score\n",
|
1597 |
+
"df_vot.at[\"test_neg_log_loss\", \"TestSet\"] = -log_loss_value\n",
|
1598 |
+
"\n",
|
1599 |
+
"display(df_vot.style.set_caption(\"Test set validation scores for Composite Model\"))"
|
1600 |
+
]
|
1601 |
+
},
|
1602 |
+
{
|
1603 |
+
"cell_type": "markdown",
|
1604 |
+
"metadata": {},
|
1605 |
+
"source": [
|
1606 |
+
"The composite model is not better than the individual models. Some tuning might improve it, but that is left for future work."
|
1607 |
+
]
|
1608 |
+
},
|
1609 |
+
{
|
1610 |
+
"cell_type": "code",
|
1611 |
+
"execution_count": 11,
|
1612 |
+
"metadata": {},
|
1613 |
+
"outputs": [
|
1614 |
+
{
|
1615 |
+
"data": {
|
1616 |
+
"text/plain": [
|
1617 |
+
"['c:\\\\Users\\\\grego\\\\OneDrive\\\\Documentos\\\\Documentos Pessoais\\\\00_DataCamp\\\\09_VSC\\\\poa_car_accidents\\\\poa_car_accidents\\\\model\\\\model_feridos.pkl']"
|
1618 |
+
]
|
1619 |
+
},
|
1620 |
+
"execution_count": 11,
|
1621 |
+
"metadata": {},
|
1622 |
+
"output_type": "execute_result"
|
1623 |
+
}
|
1624 |
+
],
|
1625 |
+
"source": [
|
1626 |
+
"# Saving\n",
|
1627 |
+
"# file_name = \"model_\" + output + '.pkl'\n",
|
1628 |
+
"# jb.dump(cls_vot, path.join(path.abspath(\"./\"), file_name))"
|
1629 |
+
]
|
1630 |
+
}
|
1631 |
+
],
|
1632 |
+
"metadata": {
|
1633 |
+
"kernelspec": {
|
1634 |
+
"display_name": "Python 3.10.6 64-bit",
|
1635 |
+
"language": "python",
|
1636 |
+
"name": "python3"
|
1637 |
+
},
|
1638 |
+
"language_info": {
|
1639 |
+
"codemirror_mode": {
|
1640 |
+
"name": "ipython",
|
1641 |
+
"version": 3
|
1642 |
+
},
|
1643 |
+
"file_extension": ".py",
|
1644 |
+
"mimetype": "text/x-python",
|
1645 |
+
"name": "python",
|
1646 |
+
"nbconvert_exporter": "python",
|
1647 |
+
"pygments_lexer": "ipython3",
|
1648 |
+
"version": "3.10.6"
|
1649 |
+
},
|
1650 |
+
"orig_nbformat": 4,
|
1651 |
+
"vscode": {
|
1652 |
+
"interpreter": {
|
1653 |
+
"hash": "1372d04dbd71fdc5436c5d6e671c1b9287e750e86143c81b5a7ba0acaf653c5e"
|
1654 |
+
}
|
1655 |
+
}
|
1656 |
+
},
|
1657 |
+
"nbformat": 4,
|
1658 |
+
"nbformat_minor": 2
|
1659 |
+
}
|
model/scaler_feridos.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cf819a52c70bb4058de5851aedbd92da179246c132e87593420f87850f7ae1b0
|
3 |
+
size 2387
|
model/scaler_feridos_gr.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:063891bf93f83929da6294c645de8d503fcec2c756cde601933800c7d3dc6598
|
3 |
+
size 2387
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
joblib
|
2 |
+
numpy
|
3 |
+
pandas
|
4 |
+
gradio
|
5 |
+
scikit-learn
|
6 |
+
xgboost==1.6.2
|