mgyigit commited on
Commit
c0756c5
1 Parent(s): b308ad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -97
app.py CHANGED
@@ -15,103 +15,6 @@ from src.saving_utils import *
15
  from src.vis_utils import *
16
  from src.bin.PROBE import run_probe
17
 
18
- global data_component, filter_component
19
-
20
- def get_method_color(method):
21
- return color_dict.get(method, 'black') # If method is not in color_dict, use black
22
-
23
- def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
24
- df = pd.read_csv(CSV_RESULT_PATH)
25
- # Filter the dataframe based on selected methods
26
- filtered_df = df[df['method_name'].isin(methods_selected)]
27
-
28
- def get_method_color(method):
29
- return color_dict.get(method.upper(), 'black')
30
-
31
- # Add a new column to the dataframe for the color
32
- filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
33
-
34
- adjust_text_dict = {
35
- 'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
36
- 'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
37
- 'force_text': (.0, 1.), 'force_objects': (.0, 1.),
38
- 'lim': 500000, 'precision': 1., 'avoid_points': True, 'avoid_text': True
39
- }
40
-
41
- # Create the scatter plot using plotnine (ggplot)
42
- g = (p9.ggplot(data=filtered_df,
43
- mapping=p9.aes(x=x_metric, # Use the selected x_metric
44
- y=y_metric, # Use the selected y_metric
45
- color='color', # Use the dynamically generated color
46
- label='method_names')) # Label each point by the method name
47
- + p9.geom_point(size=3) # Add points with no jitter, set point size
48
- + p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
49
- + p9.labs(title=title, x=f"{x_metric}", y=f"{y_metric}") # Dynamic labels for X and Y axes
50
- + p9.scale_color_identity() # Use colors directly from the dataframe
51
- + p9.theme(legend_position='none',
52
- figure_size=(8, 8), # Set figure size
53
- axis_text=p9.element_text(size=10),
54
- axis_title_x=p9.element_text(size=12),
55
- axis_title_y=p9.element_text(size=12))
56
- )
57
-
58
- # Save the plot as an image
59
- save_path = "./plot_images" # Ensure this folder exists or adjust the path
60
- os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
61
- filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
62
-
63
- g.save(filename=filename, dpi=400)
64
-
65
- return filename
66
-
67
- def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
68
- if benchmark_type == 'flexible':
69
- # Use general visualizer logic
70
- return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
71
- elif benchmark_type == 'similarity':
72
- title = f"{x_metric} vs {y_metric}"
73
- return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
74
- elif benchmark_type == 'Benchmark 3':
75
- return benchmark_3_plot(x_metric, y_metric)
76
- elif benchmark_type == 'Benchmark 4':
77
- return benchmark_4_plot(x_metric, y_metric)
78
- else:
79
- return "Invalid benchmark type selected."
80
-
81
-
82
- def get_baseline_df(selected_methods, selected_metrics):
83
- df = pd.read_csv(CSV_RESULT_PATH)
84
- present_columns = ["method_name"] + selected_metrics
85
- df = df[df['method_name'].isin(selected_methods)][present_columns]
86
- return df
87
-
88
- def general_visualizer(methods_selected, x_metric, y_metric):
89
- df = pd.read_csv(CSV_RESULT_PATH)
90
- filtered_df = df[df['method_name'].isin(methods_selected)]
91
-
92
- # Create a Seaborn lineplot with method as hue
93
- plt.figure(figsize=(10, 8)) # Increase figure size
94
- sns.lineplot(
95
- data=filtered_df,
96
- x=x_metric,
97
- y=y_metric,
98
- hue="method_name", # Different colors for different methods
99
- marker="o", # Add markers to the line plot
100
- )
101
-
102
- # Add labels and title
103
- plt.xlabel(x_metric)
104
- plt.ylabel(y_metric)
105
- plt.title(f'{y_metric} vs {x_metric} for selected methods')
106
- plt.grid(True)
107
-
108
- # Save the plot to display it in Gradio
109
- plot_path = "plot.png"
110
- plt.savefig(plot_path)
111
- plt.close()
112
-
113
- return plot_path
114
-
115
  def add_new_eval(
116
  human_file,
117
  skempi_file,
 
15
  from src.vis_utils import *
16
  from src.bin.PROBE import run_probe
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def add_new_eval(
19
  human_file,
20
  skempi_file,