michalisG committed on
Commit
acea606
1 Parent(s): 74f37f1

Add application file-1

Browse files
Files changed (2) hide show
  1. app.py +65 -0
  2. config.json +35 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import shap
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import joblib
6
+ import matplotlib.pyplot as plt
7
+ #
8
+
9
+ from utils.data_processor import DataProcessor
10
+ from utils.model_predictor import ModelPredictor
11
+ from utils.user_input_features_collector import UserInputDataCollector
12
+
13
+ model = joblib.load('resources/model.joblib')
14
+ categorical_names = joblib.load('resources/categorical_names.pkl')
15
+ target_labels = joblib.load('resources/target_labels.pkl')
16
+ selected_features = []
17
+
18
+ shap_explainer = shap.TreeExplainer(model.named_steps['RandomForestClassifier'])
19
+ data_processor = DataProcessor(model, categorical_names, selected_features)
20
+ predictor = ModelPredictor(model)
21
+
22
+ st.write("### Enter Patient Information for Diagnosis Prediction")
23
+ data = UserInputDataCollector.user_input_features()
24
+ user_input = pd.DataFrame(data, index=[0])
25
+
26
+ st.write("#### Patient Data")
27
+ st.write(user_input)
28
+
29
+ # Handle the "Predict" button:
30
+ if st.button("Predict"):
31
+ prediction, probabilities = predictor.predict(user_input)
32
+ col1, col2 = st.columns(2)
33
+ labels_map = {0: "Transplant/Death", 1: "Survive"}
34
+ label = labels_map.get(int(np.argmax(probabilities)))
35
+
36
+ # with col1:
37
+ # st.subheader("Prediction")
38
+ # st.write(label)
39
+ #
40
+ # with col2:
41
+ st.subheader("Prediction Probabilities")
42
+ # Create a DataFrame for the probabilities to display them in a more readable format
43
+ proba_df = pd.DataFrame(probabilities, columns=labels_map.values())
44
+ st.dataframe(proba_df) # Use st.dataframe so the probabilities table is interactive
45
+
46
+ i = 0
47
+ preprocessed_input = data_processor.shap_and_eli5_custom_format(user_input)
48
+ shap_values = shap_explainer.shap_values(preprocessed_input)
49
+ # np.argmax(probabilities)
50
+ shap_explanation = shap.Explanation(values=shap_values[np.argmax(probabilities)][0, :],
51
+ base_values=shap_explainer.expected_value[np.argmax(probabilities)],
52
+ data=user_input.iloc[0, :],
53
+ feature_names=user_input.columns.tolist())
54
+
55
+ # Generate the SHAP waterfall plot
56
+ shap.plots.waterfall(shap_explanation, max_display=len(user_input.columns.tolist()), show=False)
57
+ # After generating the SHAP plot, grab the current figure
58
+ fig = plt.gcf()
59
+ fig.set_size_inches(10, 7, forward=True)
60
+ # Set the figure title to the predicted label
61
+ fig.suptitle(f'Prediction: {label}', fontsize=20, y=1.05)
62
+ # Display the figure in Streamlit, passing it explicitly to ensure thread safety
63
+ st.pyplot(fig)
64
+ # Reset the default plot size if necessary
65
+ plt.rcParams['figure.figsize'] = plt.rcParamsDefault['figure.figsize']
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "RandomForestClassifier",
3
+ "expected_features": [
4
+ "age",
5
+ "sex",
6
+ "serum_bilirubin",
7
+ "serum_cholesterol",
8
+ "albumin",
9
+ "alkaline_phosphatase",
10
+ "SGOT",
11
+ "platelets",
12
+ "prothrombin_time"
13
+ ],
14
+ "categorical_features": [
15
+ "drug",
16
+ "sex",
17
+ "presence_of_ascites",
18
+ "presence_of_hepatomegaly",
19
+ "presence_of_spiders",
20
+ "presence_of_edema"
21
+ ],
22
+ "model_parameters": {
23
+ "criterion": "entropy",
24
+ "max_features": 0.1,
25
+ "min_samples_split": 8,
26
+ "min_samples_leaf": 6,
27
+ "bootstrap": true
28
+ },
29
+ "version": "1.0",
30
+ "preprocessing": {
31
+ "numerical": "median imputation and scaling",
32
+ "categorical": "one-hot encoding",
33
+ "ordinal": "label encoding"
34
+ }
35
+ }