Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import re | |
import os | |
import json | |
import yaml | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import plotnine as p9 | |
import sys | |
# Directory containing this file, so sibling modules such as `about` resolve
# no matter which working directory the app is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
# The working-directory-relative '..' and '.' entries are kept for backward
# compatibility; paths already on sys.path are not appended a second time.
for _extra_path in (script_dir, os.path.dirname(script_dir), '..', '.'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path
# Star import: presumably supplies CSV_RESULT_PATH, color_dict,
# group_color_dict and benchmark_specific_metrics, which this module uses but
# never defines itself — verify against about.py.
from about import *
def get_baseline_df(selected_methods, selected_metrics):
    """Load the results CSV and keep only the chosen methods and metric columns.

    Args:
        selected_methods: Values of ``method_name`` whose rows are kept.
        selected_metrics: List of metric column names to keep, in order.

    Returns:
        DataFrame whose columns are ``method_name`` followed by
        ``selected_metrics``.
    """
    results = pd.read_csv(CSV_RESULT_PATH)
    keep_rows = results['method_name'].isin(selected_methods)
    keep_cols = ["method_name"] + selected_metrics
    return results.loc[keep_rows, keep_cols]
def get_method_color(method):
    """Return the colour registered for ``method`` in ``color_dict``.

    The name is looked up exactly as given; unknown methods fall back to
    ``'black'``.
    """
    fallback = 'black'
    return color_dict[method] if method in color_dict else fallback
def set_colors_and_marks_for_representation_groups(ax, axis='both'):
    """Colour, embolden and caret-mark representation names in tick labels.

    Every tick label on the requested axis/axes is coloured with
    ``group_color_dict[label_text]`` (black when the text has no entry) and
    set in bold. Labels reading ``MUT2VEC``, ``PFAM``, ``GENE2VEC`` or
    ``BERT-PFAM`` are rewritten as ``"^ <name>"``.

    Args:
        ax: Matplotlib Axes whose tick labels are styled.
        axis: ``'x'``, ``'y'`` or ``'both'``. Defaults to ``'both'`` because
            the plotters in this module put method names on the y axis
            (heatmap rows, horizontal boxplot categories); styling only the
            x axis would never reach them.

    Raises:
        ValueError: If ``axis`` is not one of the accepted values.
    """
    caret_marked = {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}
    axis_map = {'x': (ax.xaxis,), 'y': (ax.yaxis,), 'both': (ax.xaxis, ax.yaxis)}
    if axis not in axis_map:
        raise ValueError(f"axis must be 'x', 'y' or 'both', got {axis!r}")

    for target in axis_map[axis]:
        labels = target.get_ticklabels()
        texts = [label.get_text() for label in labels]
        locs = target.get_ticklocs()
        if len(locs) == len(texts) and any(text in caret_marked for text in texts):
            # Text.set_text() on an existing tick label is overwritten by the
            # axis formatter when the figure is drawn, so pin the current tick
            # positions and install the caret-prefixed strings as fixed labels.
            target.set_ticks(locs)
            labels = target.set_ticklabels(
                [f"^ {text}" if text in caret_marked else text for text in texts]
            )
        for label, text in zip(labels, texts):
            # Colour is keyed on the original (un-prefixed) label text.
            label.set_color(group_color_dict.get(text, 'black'))
            label.set_fontweight('bold')
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
    """Dispatch to the plotter for ``benchmark_type`` and return its result.

    Args:
        benchmark_type: 'flexible', 'similarity', 'function', 'family' or
            'affinity'.
        methods_selected: Method names to include in the plot.
        x_metric: X-axis metric for 'flexible'/'similarity'; passed as the
            aspect for 'function' and as the metric for 'family'/'affinity'.
        y_metric: Y-axis metric for 'flexible'/'similarity'; passed as the
            metric for 'function'; unused for 'family'/'affinity'.

    Returns:
        Whatever the selected plotter returns (a saved image path; the family
        and affinity plotters return None when no matching columns exist), or
        None for an unrecognised ``benchmark_type``.
    """
    if benchmark_type == 'flexible':
        # The line-plot visualiser defined in this module is `general_visualizer`.
        return general_visualizer(methods_selected, x_metric=x_metric, y_metric=y_metric)
    if benchmark_type == 'similarity':
        title = f"{x_metric} vs {y_metric}"
        # The similarity scatter plot defined in this module is `plot_similarity_results`.
        return plot_similarity_results(methods_selected, x_metric, y_metric, title)
    if benchmark_type == 'function':
        return plot_function_results("./data/function_results.csv", x_metric, y_metric, methods_selected)
    if benchmark_type == 'family':
        return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
    if benchmark_type == "affinity":
        return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
    return None
def general_visualizer(methods_selected, x_metric, y_metric):
    """Plot ``y_metric`` against ``x_metric`` with one coloured line per method.

    Reads the results CSV at ``CSV_RESULT_PATH``, keeps rows whose
    ``method_name`` is in ``methods_selected`` and writes the chart to
    ``plot.png`` in the current working directory.

    Returns:
        The path of the saved image, ``"plot.png"``.
    """
    results = pd.read_csv(CSV_RESULT_PATH)
    subset = results.loc[results['method_name'].isin(methods_selected)]

    fig, ax = plt.subplots(figsize=(10, 8))
    # hue splits the data into one line (with point markers) per method.
    sns.lineplot(data=subset, x=x_metric, y=y_metric, hue="method_name", marker="o", ax=ax)
    ax.set_xlabel(x_metric)
    ax.set_ylabel(y_metric)
    ax.set_title(f'{y_metric} vs {x_metric} for selected methods')
    ax.grid(True)

    output_path = "plot.png"
    fig.savefig(output_path)
    plt.close(fig)
    return output_path
def plot_similarity_results(methods_selected, x_metric, y_metric, title):
    """Scatter-plot two similarity metrics with one labelled point per method.

    Args:
        methods_selected: Values of ``method_name`` (in the results CSV at
            ``CSV_RESULT_PATH``) to plot.
        x_metric: Column plotted on the x axis.
        y_metric: Column plotted on the y axis.
        title: Plot title; also used, with spaces replaced by underscores,
            in the output filename.

    Returns:
        Path of the PNG written under ``./plot_images`` (created if missing).
    """
    df = pd.read_csv(CSV_RESULT_PATH)
    # `.assign` returns a new frame, so the colour column is not written into
    # a filtered slice of `df` (which raises pandas' SettingWithCopyWarning).
    # Colours are looked up by the upper-cased method name, unlike the
    # module-level get_method_color, which uses the name as given.
    filtered_df = df[df['method_name'].isin(methods_selected)].assign(
        color=lambda frame: frame['method_name'].apply(
            lambda method: color_dict.get(method.upper(), 'black')
        )
    )

    g = (p9.ggplot(data=filtered_df,
                   mapping=p9.aes(x=x_metric,
                                  y=y_metric,
                                  color='color',          # pre-computed colour column
                                  label='method_name'))   # text label per point
         + p9.geom_point(size=3)
         + p9.geom_text(nudge_y=0.02, size=8)  # labels nudged slightly above the points
         + p9.labs(title=title, x=f"{x_metric}", y=f"{y_metric}")
         + p9.scale_color_identity()  # use the colour strings as-is
         + p9.theme(legend_position='none',
                    figure_size=(8, 8),
                    axis_text=p9.element_text(size=10),
                    axis_title_x=p9.element_text(size=12),
                    axis_title_y=p9.element_text(size=12))
         )

    save_path = "./plot_images"
    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
    g.save(filename=filename, dpi=400)
    return filename
def plot_function_results(file_path, aspect, metric, method_names):
    """Draw an annotated heatmap of function-prediction scores per method.

    Args:
        file_path: CSV with a ``Method`` column and score columns named
            ``"<aspect>_..._<metric>"``.
        aspect: Column prefix (text before the first ``_``) to select.
        metric: Column suffix (text after the last ``_``) to select.
        method_names: Values of ``Method`` to include as heatmap rows.

    Returns:
        Path of the PNG written under ``./plot_images``, or None when no
        selected method or no matching column exists (nothing to draw).
    """
    df = pd.read_csv(file_path)
    df = df[df['Method'].isin(method_names)]
    columns_to_plot = [
        col for col in df.columns
        if col.startswith(f"{aspect}_") and col.endswith(f"_{metric}")
    ]
    # Same "nothing to plot -> None" contract as the family/affinity plotters;
    # clustermap cannot draw an empty matrix.
    if df.empty or not columns_to_plot:
        print(f"No data found for aspect '{aspect}' and metric '{metric}'.")
        return None
    df = df.set_index('Method')[columns_to_plot]

    # Plain annotated heatmap: clustering is disabled on both axes.
    g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
    ax = g.ax_heatmap
    ax.set_xlabel("")
    ax.set_ylabel("")
    set_colors_and_marks_for_representation_groups(ax)

    save_path = "./plot_images"
    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
    # clustermap creates its own figure, which is the current one here.
    plt.savefig(filename, dpi=400, bbox_inches='tight')
    plt.close()
    return filename
def plot_family_results(file_path, method_names, metric, save_path="./plot_images"):
    """Draw horizontal per-fold boxplots of a family-benchmark metric per method.

    Args:
        file_path: CSV with a ``Method`` column and one column per fold whose
            name starts with ``"<metric>_"`` and ends with an integer fold
            index after the last underscore.
        method_names: Values of ``Method`` to include.
        metric: Column prefix selecting the metric to plot.
        save_path: Directory the PNG is written to; created if missing.

    Returns:
        Path of the saved PNG, or None when no column starts with
        ``"<metric>_"`` or none of ``method_names`` is present.
    """
    df = pd.read_csv(file_path)
    df = df[df['Method'].isin(method_names)]
    metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
    if not metric_columns:
        print(f"No columns found for metric '{metric}'.")
        return None
    if df.empty:
        print("No rows found for the selected methods.")
        return None

    # Long format: one row per (Method, fold column), fold index parsed from
    # the text after the last underscore of the column name.
    df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
    df_long['Fold'] = df_long['Fold'].apply(lambda col: int(col.split('_')[-1]))

    # Imported here, after the early returns, because only drawing needs it.
    from matplotlib import ticker

    sns.set_theme(style="whitegrid", color_codes=True)
    # Explicit figure size instead of a global rcParams['figure.figsize']
    # change, so the size does not leak into later plots in the process.
    fig, ax = plt.subplots(figsize=(13.7, 18.27))
    # whis=inf extends the whiskers to the data min/max, so no outlier markers.
    sns.boxplot(data=df_long, x='Value', y='Method', hue='Fold', whis=float('inf'), orient="h", ax=ax)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(0.2))
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    # First argument passed positionally: its keyword was renamed from `b` to
    # `visible` (deprecated in Matplotlib 3.5, later removed).
    ax.grid(True, which='major', color='gainsboro', linewidth=1.0)
    ax.grid(True, which='minor', color='whitesmoke', linewidth=0.5)
    ax.set_xlim(0, 1)

    # Dashed separators half-way between consecutive method rows.
    for ytick in ax.get_yticks():
        ax.hlines(ytick + 0.5, -0.1, 1, linestyles='dashed')

    set_colors_and_marks_for_representation_groups(ax)

    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, f"{metric}_family_results.png")
    fig.savefig(filename, dpi=400, bbox_inches='tight')
    plt.close(fig)
    return filename
def plot_affinity_results(file_path, method_names, metric, save_path="./plot_images"):
    """Draw horizontal per-fold boxplots of an affinity-benchmark metric per method.

    Args:
        file_path: CSV with a ``Method`` column and one column per fold whose
            name starts with ``"<metric>_"`` and ends with an integer fold
            index after the last underscore.
        method_names: Values of ``Method`` to include.
        metric: Column prefix selecting the metric to plot.
        save_path: Directory the PNG is written to; created if missing.

    Returns:
        Path of the saved PNG, or None when no column starts with
        ``"<metric>_"`` or none of ``method_names`` is present.
    """
    df = pd.read_csv(file_path)
    df = df[df['Method'].isin(method_names)]
    metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
    if not metric_columns:
        print(f"No columns found for metric '{metric}'.")
        return None
    if df.empty:
        print("No rows found for the selected methods.")
        return None

    # Long format: one row per (Method, fold column), fold index parsed from
    # the text after the last underscore of the column name.
    df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
    df_long['Fold'] = df_long['Fold'].apply(lambda col: int(col.split('_')[-1]))

    # Imported here, after the early returns, because only drawing needs it.
    from matplotlib import ticker

    sns.set_theme(style="whitegrid", color_codes=True)
    # Explicit figure size instead of a global rcParams['figure.figsize']
    # change, so the size does not leak into later plots in the process.
    fig, ax = plt.subplots(figsize=(13.7, 8.27))
    # whis=inf extends the whiskers to the data min/max, so no outlier markers.
    sns.boxplot(data=df_long, x='Value', y='Method', hue='Fold', whis=float('inf'), orient="h", ax=ax)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    # First argument passed positionally: its keyword was renamed from `b` to
    # `visible` (deprecated in Matplotlib 3.5, later removed).
    ax.grid(True, which='major', color='gainsboro', linewidth=1.0)
    ax.grid(True, which='minor', color='whitesmoke', linewidth=0.5)

    set_colors_and_marks_for_representation_groups(ax)

    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, f"{metric}_affinity_results.png")
    fig.savefig(filename, dpi=400, bbox_inches='tight')
    plt.close(fig)
    return filename
def update_metric_choices(benchmark_type):
    """Build Gradio updates that show only the selectors relevant to a benchmark.

    Returns a 6-tuple of ``gr.update`` dicts. From the option lists each slot
    receives, the slots are presumably: x metric, y metric, aspect type,
    dataset type, dataset, metric (confirm against the caller's ``outputs=``).

    - 'similarity': shows slots 1-2 with ``benchmark_specific_metrics
      ['similarity']``, preselecting its 1st and 2nd entries.
    - 'function': shows slots 3-4 (``aspect_types`` / ``dataset_types``).
    - 'family': shows slots 5-6 (``datasets`` / ``metrics``).
    - 'affinity': shows slot 6 only.
    - anything else: hides all six.

    A shown selector whose option list is too short to provide the
    preselected entry gets ``value=None`` instead of raising IndexError.
    """
    def _hidden():
        # A fresh dict per slot, so no update object is shared between outputs.
        return gr.update(visible=False)

    def _shown(options, index=0):
        # The similarity lookup falls back to [], which would otherwise raise
        # IndexError on options[0] / options[1].
        value = options[index] if len(options) > index else None
        return gr.update(choices=options, value=value, visible=True)

    if benchmark_type == 'similarity':
        metric_names = benchmark_specific_metrics.get(benchmark_type, [])
        return (_shown(metric_names, 0), _shown(metric_names, 1),
                _hidden(), _hidden(), _hidden(), _hidden())
    if benchmark_type == 'function':
        spec = benchmark_specific_metrics[benchmark_type]
        return (_hidden(), _hidden(),
                _shown(spec['aspect_types']), _shown(spec['dataset_types']),
                _hidden(), _hidden())
    if benchmark_type == 'family':
        spec = benchmark_specific_metrics[benchmark_type]
        return (_hidden(), _hidden(), _hidden(), _hidden(),
                _shown(spec['datasets']), _shown(spec['metrics']))
    if benchmark_type == 'affinity':
        metrics = benchmark_specific_metrics[benchmark_type]
        return (_hidden(), _hidden(), _hidden(), _hidden(),
                _hidden(), _shown(metrics))
    return tuple(_hidden() for _ in range(6))