aaronbi commited on
Commit
5c2f56c
1 Parent(s): c84c8e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.preprocessing import *
2
+ from sklearn.model_selection import *
3
+ import numpy as np
4
+ from sklearn.linear_model import LogisticRegression
5
+ from sklearn.preprocessing import LabelEncoder
6
+ from keras.models import Sequential
7
+ from keras.layers import Dense
8
+ from sklearn.metrics import *
9
+ from sklearn.svm import SVC
10
+ import pandas as pd
11
+
12
+ import gradio as gr
13
+ from joblib import dump, load
14
+
15
+ #loading models
16
+ log = load('logistic_model.joblib')
17
+ knn = load('knn_model.joblib')
18
+ decision = load('decision_tree_model.joblib')
19
+ deep = load('deep_model.joblib')
20
+
21
+ #input/output modules
22
+ input_module1 = gr.Dropdown(choices=["Logistic Regression", "KNN", "Decision Tree","Neural Network"], label = "method")
23
+ input_module2 = gr.Dropdown(choices=["male", 'female'], label = "gender")
24
+ input_module3 = gr.Dropdown(choices=['african-american','asian','black','black african','black or african american',\
25
+ 'caucasian', 'han chinese', 'hispanic', 'intermediate', 'japanese', 'korean', 'other',\
26
+ 'Other (Black British)','Other Mixed Race','White'], label = "race")
27
+ input_module4 = gr.Number(label='age')
28
+ input_module5 = gr.Number(label='height (cm)')
29
+ input_module6 = gr.Number(label='weight (kg)')
30
+ input_module7 = gr.Checkbox(label='Diabetes')
31
+ input_module8 = gr.Checkbox(label='Simvastatin (Zocor)')
32
+ input_module9 = gr.Checkbox(label='Amiodarone (Cordarone)')
33
+ input_module10 = gr.Number(label='target INR')
34
+ input_module11 = gr.Number(label='INR on Reported Therapeutic Dose of Warfarin')
35
+ input_module12 = gr.Number(label='Cyp2C9 genotypes (1-13)')
36
+ input_module13 = gr.Dropdown(choices=['A/A','A/G','G/G'], label = "VKORC1 genotype")
37
+
38
+ output_module = gr.Number(label='Therapeutic Dose of Warfarin (>=30 mg/wk) (1=true, 0=false)')
39
+
40
+ race_options = ['african-american','asian','black','black african','black or african american',\
41
+ 'caucasian', 'han chinese', 'hispanic', 'intermediate', 'japanese', 'korean', 'other',\
42
+ 'Other (Black British)','Other Mixed Race','White']
43
+
44
+ #gradio function
45
+ def predict(method, gender, race, age, height, weight, diabetes, simv, amio, targetINR, INR, cyp2c9, vkorc1):
46
+
47
+ #converting inputs into numeric data
48
+ if gender == 'male':
49
+ gender = 0
50
+ else:
51
+ gender = 1
52
+
53
+ for i in range(len(race_options)):
54
+ if race_options[i] == race:
55
+ race_options[i] = 1
56
+ else:
57
+ race_options[i] = 0
58
+
59
+ if diabetes == True:
60
+ diabetes = 1
61
+ else:
62
+ diabetes = 0
63
+
64
+ if simv == True:
65
+ simv = 1
66
+ else:
67
+ simv = 0
68
+
69
+ if amio == True:
70
+ amio = 1
71
+ else:
72
+ amio = 0
73
+
74
+ if vkorc1 == 'A/A':
75
+ vkorc1 = 0
76
+ elif vkorc1 == 'A/G':
77
+ vkorc1 = 1
78
+ else:
79
+ vkorc1 = 2
80
+
81
+ #compiling data
82
+ data = [gender, age, height, weight, diabetes, simv, amio, targetINR, INR, cyp2c9, vkorc1]
83
+ data.extend(race_options)
84
+ data.extend([0]) #accounting for extra unused race category (there are 2 'other' options for race)
85
+
86
+ #predicting using given method
87
+ if method == "Logistic Regression":
88
+ value = log.predict([data])[0]
89
+ elif method == "KNN":
90
+ value = knn.predict([data])[0]
91
+ elif method == "Decision Tree":
92
+ value = decision.predict([data])[0]
93
+ else:
94
+ value = deep.predict([data])[0]
95
+ if value >= 0.5:
96
+ value = 1
97
+ else:
98
+ value = 0
99
+
100
+
101
+ return value
102
+
103
+
104
+
105
+ gr.Interface(fn=predict, inputs=[input_module1,input_module2,input_module3,input_module4,input_module5,input_module6,input_module7,\
106
+ input_module8,input_module9,input_module10,input_module11,input_module12,input_module13], outputs=output_module).launch(debug=True)