|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
import pandas as pd |
|
import yaml |
|
import shap |
|
import matplotlib.pyplot as plt |
|
from inference_polymers_gnn import predict |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_shap_plot(shap_values, df, num_targets=1):
    """Draw SHAP summary plots for two targets side by side and return the figure.

    Args:
        shap_values: Indexable collection of SHAP values; entries ``[0]`` and
            ``[1]`` are each passed to ``shap.summary_plot`` together with ``df``.
        df: Feature DataFrame; its columns are used as the feature names.
        num_targets: Not used -- exactly two panels are always drawn. Kept so
            existing callers keep working.

    Returns:
        The current matplotlib figure, containing both summary plots.
    """
    # Close (rather than merely clear) the figure left over from the previous
    # call: pyplot keeps every figure registered until it is closed, so the old
    # clear-then-create sequence accumulated one figure per call. The guard
    # avoids creating a throwaway empty figure when none is open yet.
    if plt.get_fignums():
        plt.close()

    plt.figure(figsize=(15, 15))

    # Panel 1 shows shap_values[0], panel 2 shows shap_values[1].
    for target_idx in (0, 1):
        plt.subplot(1, 2, target_idx + 1)
        shap.summary_plot(
            shap_values[target_idx],
            df,
            show=False,
            feature_names=df.columns,
            plot_size=(15, 15),
        )

    plt.tight_layout()
    # Large horizontal gap between the two panels (presumably so the second
    # panel's feature labels do not collide with the first -- TODO confirm).
    plt.subplots_adjust(wspace=2.0)

    fig = plt.gcf()
    return fig
|
|
|
|
|
def call_predict(inference_dict, cols_order, numerical_columns, target_columns):
    """Wrap ``predict`` into a gradio callback taking one argument per input.

    Args:
        inference_dict: Inference settings; only ``"model_path"`` is read and
            forwarded to ``predict``.
        cols_order: Column names for the single-row DataFrame built from the
            callback's positional arguments (one name per argument, in order).
        numerical_columns: Currently unused; kept for interface compatibility.
        target_columns: Target names; one prediction is returned per entry.

    Returns:
        A callable ``f(*values)`` that returns a list holding ``y_pred[i][0]``
        for each target index ``i``, followed by one extra placeholder value.
    """

    def predict_from_list(x_list):
        # One row of input values, with columns in the configured input order.
        df = pd.DataFrame([x_list], columns=cols_order)

        y_pred = predict(df, model_path=inference_dict["model_path"])

        # First element of each target's prediction, in target order.
        outputs = [y_pred[i][0] for i in range(len(target_columns))]

        # NOTE(review): this extra value is random (uniform between 2 and 6,
        # rounded to one decimal) and is NOT derived from the model or the
        # inputs. It looks like a placeholder for an additional output
        # component; it is kept so the number of returned values is unchanged.
        # TODO: confirm which component consumes it and supply a real value.
        outputs.append(np.round(np.random.uniform(2, 6), 1))

        return outputs

    # gradio passes one positional argument per input component.
    return lambda *x: predict_from_list(x)
|
|
|
|
|
def initialize_config(config_name):
    """Build the page theme/CSS and read the YAML app configuration.

    Args:
        config_name: Path of the YAML configuration file.

    Returns:
        A 7-tuple ``(config, input_cols_order, target_columns,
        numerical_columns, osium_theme, css_styling, example_inputs)``:

        - ``config``: the parsed YAML document.
        - ``input_cols_order``: every key of every ``input_order`` section,
          flattened in order.
        - ``target_columns``: every key of every ``output_order`` section,
          excluding names ending in ``"_uncertainty"``.
        - ``numerical_columns``: ``input_mapping`` keys whose ``comp_type`` is
          ``"Number"``.
        - ``osium_theme``: the gradio theme built from the palette below.
        - ``css_styling``: custom CSS for the page.
        - ``example_inputs``: the ``"example"`` value of each input, in
          ``input_cols_order`` order.
    """
    neutral_palette = gr.themes.Color(
        c50="#e4f3fa",
        c100="#e4f3fa",
        c200="#a1c6db",
        c300="#FFFFFF",
        c400="#e4f3fa",
        c500="#0c1538",
        c600="#a1c6db",
        c700="#475383",
        c800="#0c1538",
        c900="#a1c6db",
        c950="#0c1538",
    )
    theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=neutral_palette)

    page_css = """#submit {background: #1eccd8}
#submit:hover {background: #a2f1f6}
.output-image, .input-image, .image-preview {height: 250px !important}
.output-plot {height: 250px !important}
#interpretation {height: 250px !important}"""

    with open(config_name, "r") as config_file:
        config = yaml.safe_load(config_file)

    ordered_inputs = [key for section in config["input_order"] for key in section["keys"]]

    mapping = config["input_mapping"]
    number_inputs = [key for key, spec in mapping.items() if spec["comp_type"] == "Number"]
    examples = [mapping[key]["example"] for key in ordered_inputs]

    targets = [
        key
        for section in config["output_order"]
        for key in section["keys"]
        if not key.endswith("_uncertainty")
    ]

    return config, ordered_inputs, targets, number_inputs, theme, page_css, examples
|
|
|
|
|
def add_gradio_component(config_dict, component_key):
    """Instantiate the gradio widget described by ``config_dict[component_key]``.

    The entry's ``"comp_type"`` selects the widget class. Other keys
    (``"label"``, ``"cat_values"``, ``"minimum"``, ``"maximum"``, ``"step"``)
    are read only by the widget types that use them.

    Returns:
        The new component, or ``None`` (after printing a message) when
        ``"comp_type"`` is not a supported type.
    """
    spec = config_dict[component_key]
    kind = spec["comp_type"]

    if kind == "Text":
        return gr.Text(label=spec["label"], placeholder=spec["label"])
    if kind == "Number":
        return gr.Number(label=spec["label"], precision=3)
    if kind == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])
    if kind == "Image":
        return gr.Image(elem_classes="image-preview")
    if kind == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])
    if kind == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")
    if kind == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")
    if kind == "Slider":
        return gr.Slider(
            label=spec["label"],
            minimum=spec["minimum"],
            maximum=spec["maximum"],
            step=spec["step"],
        )

    # Unsupported widget types are reported, not raised.
    print(f"Found component type {kind} for {component_key}, which is not supported")
    return None
|
|
|
|
|
def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """Build the gradio Blocks app described by the configuration.

    Args:
        input_order: Input sections; each has a ``"markdown"`` heading and a
            ``"keys"`` list of input names, in display order.
        input_mapping: Per-input component specs keyed by name (passed to
            ``add_gradio_component``).
        output_order: Output sections, same shape as ``input_order``.
        output_mapping: Per-output component specs keyed by name.
        example_inputs: One example value per input component, in the same
            order as the flattened ``input_order`` keys.
        additional_markdown: Page text; reads ``"page_title"``,
            ``"main_title"``, ``"details"`` and, unless ``inverse_design``,
            ``"interpretation"``.
        size: Layout settings; reads ``"input_column_scale"`` and
            ``"input_column_min_width"``.
        osium_theme: gradio theme for the page.
        css_styling: Custom CSS string for the page.
        predict_fn: Predict-button callback; receives one positional argument
            per input component and must return one value per output
            component (the interpretation plot last, when it is shown).
        inverse_design: When True, the interpretation plot section is omitted.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")

        input_list = []
        output_list = []

        with gr.Row():
            # Left column: input widgets grouped by section.
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_list.append(add_gradio_component(input_mapping, col_name))

            # Right column: one sub-column of output widgets per section.
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_list.append(add_gradio_component(output_mapping, col_name))

                if not inverse_design:
                    # The interpretation plot is appended last to output_list,
                    # so predict_fn's return value must end with it.
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                output_list.append(output_interpretation)

        with gr.Row():
            gr.Examples([example_inputs], input_list)

        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )

        # The Clear handler has no inputs, so gradio calls it with zero
        # arguments. It must therefore take none; the former `lambda x: ...`
        # raised TypeError on every click. It resets every input and output.
        clear_button.click(
            lambda: [gr.update(value=None)] * (len(input_list) + len(output_list)),
            [],
            input_list + output_list,
        )

    return demo
|
|