private_polymer_compound_prediction / build_gradio_graph.py
bndl's picture
Update build_gradio_graph.py
e49d9e7 verified
# import os
# import csv
import gradio as gr
# import tensorflow as tf
# import numpy as np
import pandas as pd
import yaml
import shap
import matplotlib.pyplot as plt
from inference_polymers_gnn import predict
import numpy as np
# from datetime import datetime
# import utils
# from huggingface_hub import Repository
# import itertools
# import time
# import cv2
# from prediction_coatings import predict
# from utils import predict, unpickle_file, scale_numerical, encode_categorical
def create_shap_plot(shap_values, df, num_targets=1):
    """
    Draw side-by-side SHAP summary plots for the first two targets and return the figure.

    Args:
        shap_values: indexable collection of per-target SHAP value arrays; entries 0 and 1
            are plotted (presumably aligned with the rows/columns of ``df`` — TODO confirm).
        df: DataFrame of the explained inputs; its columns are used as feature names.
        num_targets: currently unused — exactly two panels are always drawn. Kept so callers
            passing it keep working.

    Returns:
        The current matplotlib Figure holding both summary plots.
    """
    # TODO improve shap interpreter (e.g. draw one panel per target based on num_targets)
    # Close figures left over from earlier calls rather than clearing whichever figure is
    # current: pyplot keeps every figure alive until it is explicitly closed, so creating a new
    # one on each call without closing the old ones grows memory without bound.
    plt.close("all")
    plt.figure(figsize=(15, 15))
    for panel_idx in range(2):
        plt.subplot(1, 2, panel_idx + 1)
        shap.summary_plot(
            shap_values[panel_idx],
            df,
            show=False,
            feature_names=df.columns,
            plot_size=(15, 15),
        )
    plt.tight_layout()
    # Widen the gap between panels after tight_layout so feature labels do not overlap.
    plt.subplots_adjust(wspace=2.0)
    return plt.gcf()
def call_predict(inference_dict, cols_order, numerical_columns, target_columns):
    """
    Build the Gradio callback that runs ``predict`` on one row of form inputs.

    Args:
        inference_dict: configuration mapping; its "model_path" entry is looked up on each
            call and forwarded to ``predict``.
        cols_order: column names, in the same order as the Gradio input components.
        numerical_columns: currently unused; kept for interface compatibility.
        target_columns: predicted target names; two output values are emitted per entry.

    Returns:
        A callable that takes the input component values positionally and returns a flat
        list ``[pred_0, placeholder_0, pred_1, placeholder_1, ...]``.
    """
    import logging

    logger = logging.getLogger(__name__)

    def predict_from_list(x_list):
        # A single form submission becomes a one-row DataFrame whose columns follow the
        # order of the input components.
        df = pd.DataFrame([x_list], columns=cols_order)
        logger.debug("Prediction input, shape %s:\n%s", df.shape, df)
        y_pred = predict(df, model_path=inference_dict["model_path"])
        logger.debug("Received %d prediction entries: %s", len(y_pred), y_pred)
        outputs = []
        for target_idx in range(len(target_columns)):
            # assumes y_pred[target_idx] holds the values for the single input row — TODO confirm
            outputs.append(y_pred[target_idx][0])
            # NOTE(review): placeholder, NOT a model output. A random value in [2, 6) rounded to
            # one decimal fills the output slot after each prediction (presumably the
            # "_uncertainty" component). Replace with a real estimate.
            outputs.append(np.round(np.random.uniform(2, 6), 1))
        # NOTE(review): no interpretation figure is returned; if the UI adds an interpretation
        # Plot output, the number of returned values will not match — confirm with the caller.
        return outputs

    # Gradio passes each input component's value as a separate positional argument.
    return lambda *x: predict_from_list(x)
def initialize_config(config_name):
    """
    Load the YAML configuration and build the Gradio theme and CSS used by the app.

    Args:
        config_name: path to the YAML configuration file.

    Returns:
        Tuple ``(config, input_cols_order, target_columns, numerical_columns, osium_theme,
        css_styling, example_inputs)``:
        - config: the parsed YAML document;
        - input_cols_order: input column names flattened from the "keys" of each
          ``config["input_order"]`` section, in order;
        - target_columns: column names from the ``config["output_order"]`` sections,
          excluding names ending in "_uncertainty";
        - numerical_columns: ``config["input_mapping"]`` keys whose "comp_type" is "Number";
        - osium_theme: a ``gr.themes.Default`` theme with a custom neutral palette;
        - css_styling: custom CSS string for the app;
        - example_inputs: each input column's "example" value, in input order.
    """
    osium_theme_colors = gr.themes.Color(
        c50="#e4f3fa",  # Dataframe background cell content - light mode only
        c100="#e4f3fa",  # Top corner of clear button in light mode + markdown text in dark mode
        c200="#a1c6db",  # Component borders
        c300="#FFFFFF",
        c400="#e4f3fa",  # Footer text
        c500="#0c1538",  # Text of component headers in light mode only
        c600="#a1c6db",  # Top corner of button in dark mode
        c700="#475383",  # Button text in light mode + component borders in dark mode
        c800="#0c1538",  # Markdown text in light mode
        c900="#a1c6db",  # Background of dataframe - dark mode
        c950="#0c1538",  # Background in dark mode only
    )
    # Secondary color: highlight box content when typing in light mode, and the download
    # option in dark mode. Primary color: login button in dark mode.
    osium_theme = gr.themes.Default(
        primary_hue="cyan", secondary_hue="cyan", neutral_hue=osium_theme_colors
    )
    css_styling = (
        "#submit {background: #1eccd8}\n"
        "#submit:hover {background: #a2f1f6}\n"
        ".output-image, .input-image, .image-preview {height: 250px !important}\n"
        ".output-plot {height: 250px !important}\n"
        "#interpretation {height: 250px !important}"
    )
    # Decode explicitly as UTF-8 so parsing does not depend on the platform locale.
    with open(config_name, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    input_mapping = config["input_mapping"]
    input_cols_order = [
        col_name for section_dict in config["input_order"] for col_name in section_dict["keys"]
    ]
    numerical_columns = [
        col_name
        for col_name, col_cfg in input_mapping.items()
        if col_cfg["comp_type"] == "Number"
    ]
    example_inputs = [input_mapping[col_name]["example"] for col_name in input_cols_order]
    # Uncertainty outputs are displayed but are not model targets.
    target_columns = [
        col_name
        for section_dict in config["output_order"]
        for col_name in section_dict["keys"]
        if not col_name.endswith("_uncertainty")
    ]
    return (
        config,
        input_cols_order,
        target_columns,
        numerical_columns,
        osium_theme,
        css_styling,
        example_inputs,
    )
def add_gradio_component(config_dict, component_key):
    """
    Instantiate the Gradio component described by ``config_dict[component_key]``.

    The entry's "comp_type" selects the widget. Depending on the type, "label",
    "cat_values", "minimum", "maximum" and "step" are read from the same entry.
    An unsupported type is reported on stdout, and None is returned.
    """
    spec = config_dict[component_key]
    kind = spec["comp_type"]

    if kind == "Text":
        # The label doubles as the placeholder hint.
        return gr.Text(label=spec["label"], placeholder=spec["label"])
    if kind == "Number":
        return gr.Number(label=spec["label"], precision=3)
    if kind == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])
    if kind == "Image":
        return gr.Image(elem_classes="image-preview")
    if kind == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])
    if kind == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")
    if kind == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")
    if kind == "Slider":
        return gr.Slider(
            label=spec["label"],
            minimum=spec["minimum"],
            maximum=spec["maximum"],
            step=spec["step"],
        )

    print(f"Found component type {kind} for {component_key}, which is not supported")
    return None
def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """
    Build the Gradio Blocks app described by the configuration (the app is not launched).

    Args:
        input_order: list of section dicts, each with "markdown" (section heading) and "keys"
            (input column names rendered under that heading, in order).
        input_mapping: per-column component settings passed to ``add_gradio_component``.
        output_order: same structure as ``input_order``, for the output components.
        output_mapping: per-column component settings for the outputs.
        example_inputs: one example value per input component, in input order.
        additional_markdown: texts under "page_title", "main_title" and "details", plus
            "interpretation" when ``inverse_design`` is False.
        size: layout settings with "input_column_scale" and "input_column_min_width".
        osium_theme: Gradio theme applied to the app.
        css_styling: custom CSS string.
        predict_fn: callback wired to the Predict button; it receives the input values
            positionally and must return one value per output component.
        inverse_design: when False, an "Interpretation" plot is appended as the last output.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])
        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")
        input_list = []
        output_list = []
        with gr.Row():
            # Input component section
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_list.append(add_gradio_component(input_mapping, col_name))
            # Output component section
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_list.append(add_gradio_component(output_mapping, col_name))
                if not inverse_design:
                    # Currently one plot contains all the interpretation figures
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                output_list.append(output_interpretation)
        with gr.Row():
            gr.Examples([example_inputs], input_list)
        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )
        # With inputs=[] Gradio invokes the handler with no arguments, so it must accept none.
        # It resets every input and output component to an empty value.
        clear_button.click(
            lambda: [gr.update(value=None)] * (len(input_list) + len(output_list)),
            [],
            input_list + output_list,
        )
    return demo