legend1234 committed on
Commit
cf4c3c3
1 Parent(s): 9992ded

Attempt to incorporate session state

Browse files
Files changed (1) hide show
  1. app.py +136 -213
app.py CHANGED
@@ -16,15 +16,19 @@ from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
16
  from streamlit_extras.let_it_rain import rain
17
  from streamlit_ketcher import st_ketcher
18
 
 
 
 
 
19
  st.set_page_config(
20
  page_title="BBB Permeability Prediction with Imbalanced Learning",
21
  # page_icon="🧊",
22
  layout="wide",
23
  # initial_sidebar_state="expanded",
24
  # menu_items={
25
- # 'Get Help': 'https://www.extremelycoolapp.com/help',
26
- # 'Report a bug': "https://www.extremelycoolapp.com/bug",
27
- # 'About': "# This is a header. This is an *extremely* cool app!"
28
  # }
29
  )
30
 
@@ -53,156 +57,19 @@ mol_features = None
53
  info_df = None
54
  results = None
55
  temp_file_path = None
56
-
57
-
58
- @st.cache_data
59
- def load_all_models():
60
- """Get b3clf fitted classifier"""
61
- clf_list = ["dtree", "knn", "logreg", "xgb"]
62
- sampling_list = [
63
- "borderline_SMOTE",
64
- "classic_ADASYN",
65
- "classic_RandUndersampling",
66
- "classic_SMOTE",
67
- "kmeans_SMOTE",
68
- "common",
69
- ]
70
-
71
- model_dict = {}
72
- package_name = "b3clf"
73
-
74
- for clf_str, sampling_str in it.product(clf_list, sampling_list):
75
- # joblib_fpath = os.path.join(
76
- # dirname, "pre_trained", "b3clf_{}_{}.joblib".format(clf_str, sampling_str))
77
- # pred_model = joblib.load(joblib_fpath)
78
- joblib_path_str = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
79
- with pkg_resources.resource_stream(package_name, joblib_path_str) as f:
80
- pred_model = joblib.load(f)
81
-
82
- model_dict[clf_str + "_" + sampling_str] = pred_model
83
-
84
- return model_dict
85
-
86
-
87
- @st.cache_resource
88
- def predict_permeability(clf_str, sampling_str, mol_features, info_df, threshold="none"):
89
- """Compute permeability prediction for given feature data."""
90
- # load the model
91
- pred_model = load_all_models()[clf_str + "_" + sampling_str]
92
-
93
- # load the threshold data
94
- package_name = "b3clf"
95
- with pkg_resources.resource_stream(
96
- package_name, "data/B3clf_thresholds.xlsx"
97
- ) as f:
98
- df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")
99
-
100
- # default threshold is 0.5
101
- label_pool = np.zeros(mol_features.shape[0], dtype=int)
102
-
103
- if type(mol_features) == pd.DataFrame:
104
- if mol_features.index.tolist() != info_df.index.tolist():
105
- raise ValueError(
106
- "Features_df and Info_df do not have the same index."
107
- )
108
-
109
- # get predicted probabilities
110
- info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(mol_features)[
111
- :, 1
112
- ]
113
- # get predicted label from probability using the threshold
114
- mask = np.greater_equal(
115
- info_df["B3clf_predicted_probability"].to_numpy(),
116
- # df_thres.loc[clf_str + "-" + sampling_str, threshold])
117
- df_thres.loc["xgb-classic_ADASYN", threshold],
118
- )
119
- label_pool[mask] = 1
120
-
121
- # save the predicted labels
122
- info_df["B3clf_predicted_label"] = label_pool
123
- info_df.reset_index(inplace=True)
124
-
125
- return info_df
126
-
127
-
128
- # @st.cache_resource
129
- def generate_predictions(
130
- input_fname: str = None,
131
- sep: str = "\s+|\t+",
132
- clf: str = "xgb",
133
- sampling: str = "classic_ADASYN",
134
- time_per_mol: int = 120,
135
- mol_features: pd.DataFrame = None,
136
- info_df: pd.DataFrame = None,
137
- ):
138
- """
139
- Generate predictions for a given input file.
140
- """
141
- if mol_features is None and info_df is None:
142
- # mol_tag = os.path.splitext(uploaded_file.name)[0]
143
- # uploaded_file = uploaded_file.read().decode("utf-8")
144
- mol_tag = os.path.basename(input_fname).split(".")[0]
145
- internal_sdf = f"{mol_tag}_optimized_3d.sdf"
146
-
147
- # Geometry optimization
148
- # Input:
149
- # * Either an SDF file with molecular geometries or a text file with SMILES strings
150
-
151
- geometry_optimize(input_fname=input_fname, output_sdf=internal_sdf, sep=sep)
152
-
153
- df_features = compute_descriptors(
154
- sdf_file=internal_sdf,
155
- excel_out=None,
156
- output_csv=None,
157
- timeout=None,
158
- time_per_molecule=time_per_mol,
159
- )
160
- # st.write(df_features)
161
-
162
- # Get computed descriptors
163
- mol_features, info_df = get_descriptors(df=df_features)
164
-
165
- # Select descriptors
166
- mol_features = select_descriptors(df=mol_features)
167
-
168
- # Scale descriptors
169
- mol_features.iloc[:, :] = scale_descriptors(df=mol_features)
170
-
171
- # this is problematic for using the same file for calculation
172
- if os.path.exists(internal_sdf) and keep_sdf == "no":
173
- os.remove(internal_sdf)
174
-
175
- # Get classifier
176
- # clf = get_clf(clf_str=clf, sampling_str=sampling)
177
-
178
- # Get classifier
179
- result_df = predict_permeability(
180
- clf_str=clf,
181
- sampling_str=sampling,
182
- mol_features=mol_features,
183
- info_df=info_df,
184
- threshold="none",
185
- )
186
-
187
- # Get classifier
188
- display_cols = [
189
- "ID",
190
- "SMILES",
191
- "B3clf_predicted_probability",
192
- "B3clf_predicted_label",
193
- ]
194
-
195
- result_df = result_df[
196
- [col for col in result_df.columns.to_list() if col in display_cols]
197
- ]
198
-
199
- return mol_features, info_df, result_df
200
-
201
 
202
  # Create the Streamlit app
203
  st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
204
  info_column, upload_column = st.columns(2)
205
 
 
 
 
 
 
 
 
206
  # download sample files
207
  with info_column:
208
  st.subheader("About `B3clf`")
@@ -212,10 +79,10 @@ with info_column:
212
  `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf. This project is supported by Digital Research Alliance of Canada (originally known as Compute Canada) and NSERC. This project is maintained by QC-Dev comminity. For further information and inquiries please contact us at qcdevs@gmail.com."""
213
  )
214
  st.text(" \n")
215
- # text_body = '''
216
  # `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf.
217
- # '''
218
- # st.markdown(f'<p align="justify">{text_body}</p>',
219
  # unsafe_allow_html=True)
220
 
221
  # image = Image.open("images/b3clf_workflow.png")
@@ -224,7 +91,7 @@ with info_column:
224
  # image_path = "images/b3clf_workflow.png"
225
  # image_width_percent = 80
226
  # info_column.markdown(
227
- # f'<img src="{image_path}" style="max-width: {image_width_percent}%; height: auto;">',
228
  # unsafe_allow_html=True
229
  # )
230
 
@@ -280,12 +147,42 @@ with upload_column:
280
  upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))
281
  # upload file column
282
  with upload_col:
283
- file = st.file_uploader(
 
 
 
 
 
 
 
 
 
 
284
  label="Upload a CSV, SDF, TXT or SMI file",
285
  type=["csv", "sdf", "txt", "smi"],
286
  help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
287
  accept_multiple_files=False,
 
 
288
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # submit job column
290
  with submit_job_col:
291
  st.text(" \n")
@@ -295,9 +192,9 @@ with upload_column:
295
  unsafe_allow_html=True,
296
  )
297
  submit_job_button = st.button(
298
- label="Submit Job", key="submit_job_button", type="secondary"
299
  )
300
- # submit_job_col.markdown("<div style='display: flex; justify-content: center;'>",
301
  # unsafe_allow_html=True)
302
  # submit_job_button = submit_job_col.button(
303
  # label="Submit job", key="submit_job_button", type="secondary"
@@ -329,69 +226,95 @@ with prediction_column:
329
  # placeholder_predictions.text("prediction")
330
 
331
 
 
 
 
332
  # Generate predictions when the user uploads a file
333
- if submit_job_button:
334
- if file and mol_features is None and info_df is None:
 
 
 
 
 
 
 
335
  temp_dir = tempfile.mkdtemp()
336
  # Create a temporary file path for the uploaded file
337
- temp_file_path = os.path.join(temp_dir, file.name)
338
  # Save the uploaded file to the temporary file path
339
  with open(temp_file_path, "wb") as temp_file:
340
- temp_file.write(file.read())
341
- # mol_features, results = generate_predictions(temp_file_path)
342
- mol_features, info_df, results = generate_predictions(
343
- input_fname=temp_file_path,
344
- sep="\s+|\t+",
345
- clf=classifiers_dict[classifier],
346
- sampling=resample_methods_dict[resampler],
347
- time_per_mol=120,
348
- mol_features=mol_features,
349
- info_df=info_df,
350
- )
351
- st.balloons()
352
-
353
- # feture table
354
- with feature_column:
355
- if mol_features is not None:
356
- selected_feature_rows = np.min(
357
- [mol_features.shape[0], pandas_display_options["line_limit"]]
358
  )
359
- st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
360
- # placeholder_features.dataframe(mol_features, hide_index=False)
361
- feature_file_name = file.name.split(".")[0] + "_b3clf_features.csv"
362
- features_csv = mol_features.to_csv(index=True)
363
- st.download_button(
364
- "Download features as CSV",
365
- data=features_csv,
366
- file_name=feature_file_name,
 
 
 
 
367
  )
368
- # prediction table
369
- with prediction_column:
370
- # st.subheader("Predictions")
371
- if results is not None:
372
- # Display the predictions in a table
373
- selected_result_rows = np.min(
374
- [results.shape[0], pandas_display_options["line_limit"]]
375
- )
376
- results_df_display = results.iloc[
377
- :selected_result_rows, :
378
- ].style.format({"B3clf_predicted_probability": "{:.6f}".format})
379
- st.dataframe(results_df_display, hide_index=True)
380
- # Add a button to download the predictions as a CSV file
381
- predictions_csv = results.to_csv(index=True)
382
- results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
383
- st.download_button(
384
- "Download predictions as CSV",
385
- data=predictions_csv,
386
- file_name=results_file_name,
387
- )
388
- # indicate the success of the job
389
- # rain(
390
- # emoji="🎈",
391
- # font_size=54,
392
- # falling_speed=5,
393
- # animation_length=10,
394
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
 
397
  # hide footer
@@ -412,9 +335,9 @@ st.markdown(
412
  <script>
413
  window.dataLayer = window.dataLayer || [];
414
  function gtag(){dataLayer.push(arguments);}
415
- gtag('js', new Date());
416
 
417
- gtag('config', 'G-WG8QYRELP9');
418
  </script>
419
  """,
420
  unsafe_allow_html=True,
 
16
  from streamlit_extras.let_it_rain import rain
17
  from streamlit_ketcher import st_ketcher
18
 
19
+ from utils import generate_predictions, load_all_models
20
+
21
+ st.cache_data.clear()
22
+
23
  st.set_page_config(
24
  page_title="BBB Permeability Prediction with Imbalanced Learning",
25
  # page_icon="🧊",
26
  layout="wide",
27
  # initial_sidebar_state="expanded",
28
  # menu_items={
29
+ # "Get Help": "https://www.extremelycoolapp.com/help",
30
+ # "Report a bug": "https://www.extremelycoolapp.com/bug",
31
+ # "About": "# This is a header. This is an *extremely* cool app!"
32
  # }
33
  )
34
 
 
57
  info_df = None
58
  results = None
59
  temp_file_path = None
60
+ all_models = load_all_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Create the Streamlit app
63
  st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
64
  info_column, upload_column = st.columns(2)
65
 
66
+ # initialize the molecule features and info dataframe session state
67
+ if "mol_features" not in st.session_state:
68
+ st.session_state.mol_features = None
69
+ if "info_df" not in st.session_state:
70
+ st.session_state.info_df = None
71
+
72
+
73
  # download sample files
74
  with info_column:
75
  st.subheader("About `B3clf`")
 
79
  `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf. This project is supported by Digital Research Alliance of Canada (originally known as Compute Canada) and NSERC. This project is maintained by QC-Dev comminity. For further information and inquiries please contact us at qcdevs@gmail.com."""
80
  )
81
  st.text(" \n")
82
+ # text_body = """
83
  # `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf.
84
+ # """
85
+ # st.markdown(f"<p align="justify">{text_body}</p>",
86
  # unsafe_allow_html=True)
87
 
88
  # image = Image.open("images/b3clf_workflow.png")
 
91
  # image_path = "images/b3clf_workflow.png"
92
  # image_width_percent = 80
93
  # info_column.markdown(
94
+ # f"<img src="{image_path}" style="max-width: {image_width_percent}%; height: auto;">",
95
  # unsafe_allow_html=True
96
  # )
97
 
 
147
  upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))
148
  # upload file column
149
  with upload_col:
150
+ # session state tracking of the file uploader
151
+ if "uploaded_file" not in st.session_state:
152
+ st.session_state.uploaded_file = None
153
+ if "uploaded_file_changed" not in st.session_state:
154
+ st.session_state.uploaded_file_changed = False
155
+
156
+ # def update_uploader_session_info():
157
+ # """Update the session state of the file uploader."""
158
+ # st.session_state.uploaded_file = uploaded_file
159
+
160
+ uploaded_file = st.file_uploader(
161
  label="Upload a CSV, SDF, TXT or SMI file",
162
  type=["csv", "sdf", "txt", "smi"],
163
  help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
164
  accept_multiple_files=False,
165
+ # key="uploaded_file",
166
+ # on_change=update_uploader_session_info,
167
  )
168
+
169
+ if uploaded_file:
170
+ # st.write(f"the uploaded file: {uploaded_file}")
171
+ # when the newly uploaded file is different from the previous one
172
+ if st.session_state.uploaded_file != uploaded_file:
173
+ st.session_state.uploaded_file_changed = True
174
+ else:
175
+ st.session_state.uploaded_file_changed = False
176
+ st.session_state.uploaded_file = uploaded_file
177
+ # when new file is the same as the previous one
178
+ # else:
179
+ # st.session_state.uploaded_file_changed = False
180
+ # st.session_state.uploaded_file = uploaded_file
181
+
182
+ # set session state for the file uploader
183
+ # st.write(f"the state of uploaded file: {st.session_state.uploaded_file}")
184
+ # st.write(f"the state of uploaded file changed: {st.session_state.uploaded_file_changed}")
185
+
186
  # submit job column
187
  with submit_job_col:
188
  st.text(" \n")
 
192
  unsafe_allow_html=True,
193
  )
194
  submit_job_button = st.button(
195
+ label="Submit Job", type="secondary", key="job_button"
196
  )
197
+ # submit_job_col.markdown("<div style="display: flex; justify-content: center;">",
198
  # unsafe_allow_html=True)
199
  # submit_job_button = submit_job_col.button(
200
  # label="Submit job", key="submit_job_button", type="secondary"
 
226
  # placeholder_predictions.text("prediction")
227
 
228
 
229
+ st.write(
230
+ f"the state of uploaded file changed before checking: {st.session_state.uploaded_file_changed}"
231
+ )
232
  # Generate predictions when the user uploads a file
233
+ # if submit_job_button:
234
+ print(st.session_state)
235
+ if "job_button" in st.session_state:
236
+ # when new file is uploaded
237
+ # update_uploader_session_info()
238
+ st.write(
239
+ f"the state of uploaded file changed after checking: {st.session_state.uploaded_file_changed}"
240
+ )
241
+ if st.session_state.uploaded_file_changed:
242
  temp_dir = tempfile.mkdtemp()
243
  # Create a temporary file path for the uploaded file
244
+ temp_file_path = os.path.join(temp_dir, uploaded_file.name)
245
  # Save the uploaded file to the temporary file path
246
  with open(temp_file_path, "wb") as temp_file:
247
+ temp_file.write(uploaded_file.read())
248
+
249
+ mol_features, info_df, results = generate_predictions(
250
+ input_fname=temp_file_path,
251
+ sep="\s+|\t+",
252
+ clf=classifiers_dict[classifier],
253
+ _models_dict=all_models,
254
+ sampling=resample_methods_dict[resampler],
255
+ time_per_mol=120,
256
+ mol_features=None,
257
+ info_df=None,
 
 
 
 
 
 
 
258
  )
259
+ st.session_state.mol_features = mol_features
260
+ st.session_state.info_df = info_df
261
+ else:
262
+ mol_features, info_df, results = generate_predictions(
263
+ input_fname=None,
264
+ sep="\s+|\t+",
265
+ clf=classifiers_dict[classifier],
266
+ _models_dict=all_models,
267
+ sampling=resample_methods_dict[resampler],
268
+ time_per_mol=120,
269
+ mol_features=st.session_state.mol_features,
270
+ info_df=st.session_state.info_df,
271
  )
272
+
273
+ # feature table
274
+ with feature_column:
275
+ if mol_features is not None:
276
+ selected_feature_rows = np.min(
277
+ [mol_features.shape[0], pandas_display_options["line_limit"]]
278
+ )
279
+ st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
280
+ # placeholder_features.dataframe(mol_features, hide_index=False)
281
+ feature_file_name = uploaded_file.name.split(".")[0] + "_b3clf_features.csv"
282
+ features_csv = mol_features.to_csv(index=True)
283
+ st.download_button(
284
+ "Download features as CSV",
285
+ data=features_csv,
286
+ file_name=feature_file_name,
287
+ )
288
+ # prediction table
289
+ with prediction_column:
290
+ # st.subheader("Predictions")
291
+ if results is not None:
292
+ # Display the predictions in a table
293
+ selected_result_rows = np.min(
294
+ [results.shape[0], pandas_display_options["line_limit"]]
295
+ )
296
+ results_df_display = results.iloc[:selected_result_rows, :].style.format(
297
+ {"B3clf_predicted_probability": "{:.6f}".format}
298
+ )
299
+ st.dataframe(results_df_display, hide_index=True)
300
+ # Add a button to download the predictions as a CSV file
301
+ predictions_csv = results.to_csv(index=True)
302
+ results_file_name = (
303
+ uploaded_file.name.split(".")[0] + "_b3clf_predictions.csv"
304
+ )
305
+ st.download_button(
306
+ "Download predictions as CSV",
307
+ data=predictions_csv,
308
+ file_name=results_file_name,
309
+ )
310
+ # indicate the success of the job
311
+ # rain(
312
+ # emoji="🎈",
313
+ # font_size=54,
314
+ # falling_speed=5,
315
+ # animation_length=10,
316
+ # )
317
+ st.balloons()
318
 
319
 
320
  # hide footer
 
335
  <script>
336
  window.dataLayer = window.dataLayer || [];
337
  function gtag(){dataLayer.push(arguments);}
338
+ gtag("js", new Date());
339
 
340
+ gtag("config", "G-WG8QYRELP9");
341
  </script>
342
  """,
343
  unsafe_allow_html=True,