"""Gradio app: pick a stored patient profile, predict "Transplant/Death" vs
"Survive" with a pretrained pipeline downloaded from the Hugging Face Hub, and
explain the prediction with a SHAP waterfall plot."""
import html

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import shap
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt

from utils.data import key_to_display_name_and_value_conversion, patients_data
from utils.data_processor import DataProcessor
from utils.model_predictor import ModelPredictor

# Trained sklearn pipeline; its 'RandomForestClassifier' step is what SHAP explains.
model = joblib.load(
    hf_hub_download("Proddis/pbc_complication_model", "RandomForestClassifier_trained_pipeline.joblib")
)
# NOTE(review): joblib.load unpickles these files -- only ship trusted artifacts here.
categorical_names = joblib.load("resources/categorical_names.pkl")
target_labels = joblib.load("resources/target_labels.pkl")
selected_features = []

shap_explainer = shap.TreeExplainer(model.named_steps["RandomForestClassifier"])
data_processor = DataProcessor(model, categorical_names, selected_features)
predictor = ModelPredictor(model)

labels_map = {0: "Transplant/Death", 1: "Survive"}
# Class index whose SHAP values are plotted (0 -> "Transplant/Death"); the caption
# under the plot describes arrow directions relative to this class.
EXPLAINED_CLASS = 0
plot_path = "shap_waterfall_plot.png"


def _class_shap_values(shap_values, expected_value, class_index):
    """Return ``(values, base_value)`` for the first sample and one output class.

    Depending on the shap release, TreeExplainer reports multi-output SHAP values
    either as a list holding one ``(n_samples, n_features)`` array per class, or as
    a single ``(n_samples, n_features, n_classes)`` array; both layouts are handled.
    A single-output ``(n_samples, n_features)`` array is used as-is.
    """
    if isinstance(shap_values, list):
        values = np.asarray(shap_values[class_index])[0]
    else:
        values = np.asarray(shap_values)
        values = values[0, :, class_index] if values.ndim == 3 else values[0]
    base_values = np.ravel(expected_value)
    base_value = base_values[class_index] if base_values.size > 1 else base_values[0]
    return values, base_value


def _patient_data_html(user_input):
    """Render the patient's inputs as an HTML list of display names and display values."""
    items = "".join(
        f"<li><strong>{html.escape(str(display_name))}:</strong> "
        f"{html.escape(str(converted_value))}</li>"
        for key, value in user_input.items()
        for display_name, converted_value in [data_processor.convert_value(key, value)]
    )
    return f"<div><h3>Patient Data</h3><ul>{items}</ul></div>"


def _prediction_html(label, probabilities):
    """Render the predicted label and a percentage table of class probabilities."""
    # One row per sample, one column per label in labels_map order.
    rows = [[f"{p * 100:.1f}%" for p in row] for row in np.atleast_2d(probabilities)]
    proba_df = pd.DataFrame(rows, columns=list(labels_map.values()))
    proba_html = proba_df.to_html(classes="table table-striped", header=True, index=False)
    return (
        f"<div><h3>Prediction: {html.escape(str(label))}</h3>"
        f"<h4>Probabilities:</h4>{proba_html}</div>"
    )


def _save_waterfall_plot(explanation, max_display, path):
    """Draw a SHAP waterfall plot for *explanation*, save it to *path*, return *path*."""
    try:
        shap.waterfall_plot(explanation, max_display=max_display, show=False)
        fig = plt.gcf()
        fig.set_size_inches(12, 8)
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Release the pyplot figure even when plotting/saving fails.
        plt.close(plt.gcf())
    return path


def select_and_predict(patient_selection):
    """Predict the outcome for a stored patient profile and explain it with SHAP.

    Args:
        patient_selection: Key of the profile in ``patients_data`` (the dropdown value).

    Returns:
        Tuple ``(patient_html, prediction_html, image_path)`` feeding the two HTML
        panels and the image component.

    Raises:
        gr.Error: If no known patient profile is selected.
    """
    if patient_selection not in patients_data:
        raise gr.Error("Please select a patient profile first.")
    user_input = patients_data[patient_selection]
    user_input_df = pd.DataFrame([user_input])  # single-row frame for the pipeline

    _, probabilities = predictor.predict(user_input_df)
    # argmax over the first (only) row, not over the flattened array.
    label = labels_map.get(int(np.argmax(np.atleast_2d(probabilities)[0])))

    # Human-readable feature names, falling back to a title-cased column key.
    features = [
        key_to_display_name_and_value_conversion.get(key, (key.replace("_", " ").title(), None))[0]
        for key in user_input
    ]
    preprocessed_input = data_processor.shap_and_eli5_custom_format(user_input_df)
    # Mapped (display) values are only used as the data shown in the waterfall plot.
    mapped_row = data_processor.apply_mapping_to_row(user_input_df.iloc[0])

    values, base_value = _class_shap_values(
        shap_explainer.shap_values(preprocessed_input),
        shap_explainer.expected_value,
        EXPLAINED_CLASS,
    )
    explanation = shap.Explanation(
        values=values, base_values=base_value, data=mapped_row, feature_names=features
    )
    _save_waterfall_plot(explanation, len(user_input_df.columns), plot_path)

    return _patient_data_html(user_input), _prediction_html(label, probabilities), plot_path


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("# Risk of Disease Complication in Biliary Cirrhosis Patients")

    # Patient selector; the two empty columns act as spacers that keep it narrow.
    with gr.Row():
        with gr.Column(scale=2):
            dropdown = gr.Dropdown(list(patients_data.keys()), label="Select Patient Profile")
            btn = gr.Button("Predict")
        with gr.Column(scale=2):
            pass
        with gr.Column(scale=2):
            pass

    # Input data, prediction and SHAP explanation side by side.
    with gr.Row():
        with gr.Column(scale=1):
            user_input_output = gr.HTML()
        with gr.Column(scale=1):
            prediction_html = gr.HTML()
        with gr.Column(scale=2):
            gr.Markdown("# Risk Factors of Disease Complication")
            output_image = gr.Image(show_share_button=False)
            gr.Markdown(
                "Left arrows show which features tip the scales towards 'Survive', "
                "and right arrows show which features lean towards 'Transplant/Death'."
            )

    btn.click(
        fn=select_and_predict,
        inputs=dropdown,
        outputs=[user_input_output, prediction_html, output_image],
    )

app.launch()