oort77 commited on
Commit
3ccb038
1 Parent(s): 10781d4

initial commit

Browse files
Files changed (4) hide show
  1. app.py +238 -0
  2. car_at_night.ico +0 -0
  3. requirements.txt +0 -0
  4. sky.png +0 -0
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File: app.py
3
+ # Project: 'Homework #2 OTUS.ML.Advanced'
4
+ # Created by Gennady Matveev (gm@og.ly) on 02-01-2022.
5
+
6
+ # Import libraries
7
+ import os
8
+ import pandas as pd
9
+ import streamlit as st
10
+ import requests
11
+
12
+ st.set_page_config(page_title='OTUS.ML.ADV_HW2', page_icon='./sky.ico', layout='centered', initial_sidebar_state='expanded')
13
+
14
+ padding = 0
15
+ st.markdown(f""" <style>
16
+ .reportview-container .main .block-container{{
17
+ padding-top: {padding}rem;
18
+ padding-right: {padding}rem;
19
+ padding-left: {padding}rem;
20
+ padding-bottom: {padding}rem;
21
+ }} </style> """, unsafe_allow_html=True)
22
+
23
+ st.image('./sky.png')
24
+ st.subheader('Homework #2 OTUS.ML.Advanced')
25
+ st.write('Classification model for Heart Disease UCI: &nbsp;&nbsp;https://www.kaggle.com/ronitf/heart-disease-uci')
26
+ st.markdown("""---""")
27
+
28
+ # Import data, will need it for get requests
29
+ @st.cache(ttl=600)
30
+ def get_data():
31
+ url = 'https://drive.google.com/uc?export=download&id=1wY3r2MwQoa-jiyzRoEM_eF_EU11vrCs0'
32
+ return pd.read_csv(url, compression='zip')
33
+
34
+ df = get_data()
35
+
36
+ # Main interface
37
+ row_num = st.number_input('Please choose features vector 0-302 or set values in the left sidebar',
38
+ min_value=0, max_value=302, value=42)
39
+ x17 =df.iloc[row_num,:-1].to_frame().T
40
+ st.write('Features, X')
41
+ st.write(x17)
42
+
43
+ # START Sidebar ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
44
+
45
+ with st.sidebar.expander("I want to choose my values", expanded=False):
46
+ age = st.number_input('Age', min_value=25, max_value=80, value=57)
47
+ sex = st.number_input('Sex', min_value=0, max_value=1, value=1)
48
+ cp = st.number_input('cp', min_value=0, max_value=4, value=0)
49
+ trestbps = st.number_input('trestbps', min_value=90, max_value=200, value=125)
50
+ chol = st.number_input('chol', min_value=125, max_value=550, value=240)
51
+ fbs = st.number_input('fbs', min_value=0, max_value=1, value=0)
52
+ restecg = st.number_input('restecg', min_value=0, max_value=2, value=1)
53
+ thalach = st.number_input('thalach', min_value=70, max_value=200, value=160)
54
+ exang = st.number_input('exang', min_value=0, max_value=1, value=0)
55
+ oldpeak = st.number_input('oldpeak', min_value=0, max_value=6, value=2)
56
+ slope = st.number_input('slope', min_value=0, max_value=2, value=2)
57
+ ca = st.number_input('ca', min_value=0, max_value=4, value=0)
58
+ thal = st.number_input('thal', min_value=0, max_value=3, value=2)
59
+
60
+ features = age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal
61
+ send_req_sidebar = st.button('Get prediction')
62
+
63
+ # END Sidebar ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
64
+
65
+ send_req = st.button('Send get request')
66
+
67
+ backend_address = "https://hw2backend.herokuapp.com/predict/"
68
+
69
+ # Main page button
70
+ if send_req:
71
+ prediction = requests.get(backend_address,
72
+ params={"q": tuple(x17.values)})
73
+ st.code(f'Parameters sent: {x17.values}')
74
+ col1, col2 = st.columns(2)
75
+ with col1:
76
+ st.write('Model predicts')
77
+ st.success(f'y = {prediction.text}')
78
+ with col2:
79
+ st.write('Ground truth')
80
+ if int(prediction.text) == int(df.iloc[row_num]["target"]):
81
+ st.success(f'y = {int(df.iloc[row_num]["target"])}')
82
+ else:
83
+ st.warning(f'y = {int(df.iloc[row_num]["target"])}')
84
+
85
+ # Sidebar button
86
+
87
+ if send_req_sidebar:
88
+ prediction = requests.get(backend_address,
89
+ params={"q": features})
90
+ st.code(f'Parameters sent: {features}')
91
+ st.write('Model predicts')
92
+ st.info(f'y = {prediction.text}')
93
+
94
+ # Show this code
95
+ with st.expander("Show code", expanded=False):
96
+ show_me = st.checkbox('Show code of this program')
97
+ if show_me:
98
+ st.code("""
99
+ # -*- coding: utf-8 -*-
100
+ # File: app.py
101
+ # Project: 'Homework #2 OTUS.ML.Advanced'
102
+ # Created by Gennady Matveev (gm@og.ly) on 02-01-2022.
103
+
104
+ # Import libraries
105
+ import pandas as pd
106
+ import streamlit as st
107
+ import requests
108
+
109
+ st.set_page_config(page_title='OTUS.ML.ADV_HW2', page_icon='./car_at_night.ico',
110
+ layout='centered', initial_sidebar_state='expanded')
111
+
112
+ padding = 0
113
+ st.markdown(f''' <style>
114
+ .reportview-container .main .block-container{{
115
+ padding-top: {padding}rem;
116
+ padding-right: {padding}rem;
117
+ padding-left: {padding}rem;
118
+ padding-bottom: {padding}rem;
119
+ }} </style> ''', unsafe_allow_html=True)
120
+
121
+ st.image('./sky.png')
122
+ st.subheader('Homework #2 OTUS.ML.Advanced')
123
+ st.write('Classification model for Heart Disease UCI: &nbsp;&nbsp;https://www.kaggle.com/ronitf/heart-disease-uci')
124
+ st.markdown('''---''')
125
+
126
+ # Import data, will need it for get requests
127
+ @st.cache(ttl=600)
128
+ def get_data():
129
+ url = 'https://drive.google.com/uc?export=download&id=1wY3r2MwQoa-jiyzRoEM_eF_EU11vrCs0'
130
+ return pd.read_csv(url, compression='zip')
131
+
132
+ df = get_data()
133
+
134
+ # Main interface
135
+ row_num = st.number_input('Please choose features vector 0-302 or set values in the left sidebar',
136
+ min_value=0, max_value=302, value=42)
137
+ x17 =df.iloc[row_num,:-1].to_frame().T
138
+ st.write('Features, X')
139
+ st.write(x17)
140
+
141
+ # START Sidebar ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
142
+
143
+ with st.sidebar.expander("I want to choose my values", expanded=False):
144
+ age = st.number_input('Age', min_value=25, max_value=80, value=57)
145
+ sex = st.number_input('Sex', min_value=0, max_value=1, value=1)
146
+ cp = st.number_input('cp', min_value=0, max_value=4, value=0)
147
+ trestbps = st.number_input('trestbps', min_value=90, max_value=200, value=125)
148
+ chol = st.number_input('chol', min_value=125, max_value=550, value=240)
149
+ fbs = st.number_input('fbs', min_value=0, max_value=1, value=0)
150
+ restecg = st.number_input('restecg', min_value=0, max_value=2, value=1)
151
+ thalach = st.number_input('thalach', min_value=70, max_value=200, value=160)
152
+ exang = st.number_input('exang', min_value=0, max_value=1, value=0)
153
+ oldpeak = st.number_input('oldpeak', min_value=0, max_value=6, value=2)
154
+ slope = st.number_input('slope', min_value=0, max_value=2, value=2)
155
+ ca = st.number_input('ca', min_value=0, max_value=4, value=0)
156
+ thal = st.number_input('thal', min_value=0, max_value=3, value=2)
157
+
158
+ features = age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal
159
+ send_req_sidebar = st.button('Get prediction')
160
+
161
+ # END Sidebar ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
162
+
163
+ ssend_req = st.button('Send get request')
164
+
165
+ backend_address = "https://hw2backend.herokuapp.com/predict/"
166
+
167
+ # Main page button
168
+ if send_req:
169
+ prediction = requests.get(backend_address,
170
+ params={"q": tuple(x17.values)})
171
+ st.code(f'Parameters sent: {x17.values}')
172
+ col1, col2 = st.columns(2)
173
+ with col1:
174
+ st.write('Model predicts')
175
+ st.success(f'y = {prediction.text}')
176
+ with col2:
177
+ st.write('Ground truth')
178
+ if int(prediction.text) == int(df.iloc[row_num]["target"]):
179
+ st.success(f'y = {int(df.iloc[row_num]["target"])}')
180
+ else:
181
+ st.warning(f'y = {int(df.iloc[row_num]["target"])}')
182
+
183
+ # Sidebar button
184
+
185
+ if send_req_sidebar:
186
+ prediction = requests.get(backend_address,
187
+ params={"q": features})
188
+ st.code(f'Parameters sent: {features}')
189
+ st.write('Model predicts')
190
+ st.info(f'y = {prediction.text}')
191
+ """
192
+ )
193
+
194
+ show_api = st.checkbox('Show code of FastAPI backend')
195
+ if show_api:
196
+ st.code("""
197
+ # -*- coding: utf-8 -*-
198
+ # File: main.py
199
+ # Project: 'Homework #2 OTUS.ML.Advanced'
200
+ # Created by Gennady Matveev (gm@og.ly) on 04-01-2022.
201
+ # Copyright 2022. All rights reserved.
202
+
203
+ # Import libraries
204
+ import uvicorn
205
+ from atom import ATOMLoader
206
+ from fastapi import FastAPI, Query
207
+ import pandas as pd
208
+ from typing import List, Optional
209
+
210
+ cols = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
211
+ 'exang', 'oldpeak', 'slope', 'ca', 'thal']
212
+
213
+ atom = ATOMLoader("./models/atom20220104-32256", verbose=0)
214
+
215
+ # Initialize app
216
+ app = FastAPI()
217
+
218
+ # Routes
219
+ @app.get('/')
220
+ async def index():
221
+ return {"text": "Hello, fellow ML students"}
222
+
223
+
224
+ @app.get('/predict/')
225
+ async def predict(q: Optional[List[float]] = Query(None)):
226
+ dfx = pd.DataFrame([q], columns = cols)
227
+ prediction = atom.predict(dfx)
228
+ return int(prediction[0])
229
+
230
+
231
+ if __name__ == '__main__':
232
+ # port = int(os.environ.get("PORT", 8080))
233
+ port = int(os.environ.get("PORT", 8080))
234
+ uvicorn.run("main:app", host="0.0.0.0", port=port)
235
+ """
236
+ )
237
+
238
+ st.markdown("And, finally, classification model itself on [Colab](https://colab.research.google.com/github/oort77/otusmladvhw2-notebook/blob/main/otus_adv_hw2.ipynb)")
car_at_night.ico ADDED
requirements.txt ADDED
File without changes
sky.png ADDED