mgyigit commited on
Commit
666a6e2
·
verified ·
1 Parent(s): 83ac1b8

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +14 -7
src/vis_utils.py CHANGED
@@ -183,6 +183,8 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
183
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
184
  df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
185
 
 
 
186
  # Set up the plot
187
  sns.set(rc={'figure.figsize': (13.7, 18.27)})
188
  sns.set_theme(style="whitegrid", color_codes=True)
@@ -214,23 +216,25 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
214
 
215
  return filename
216
 
217
- def plot_affinity_results(file_path, method_names, metric, save_path="./plot_images"):
218
- # Load the CSV data
219
- df = pd.read_csv(file_path)
 
 
 
220
 
221
  # Filter for selected methods
222
  df = df[df['Method'].isin(method_names)]
223
 
224
  # Gather columns related to the specified metric and validate
225
  metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
226
- if not metric_columns:
227
- print(f"No columns found for metric '{metric}'.")
228
- return None
229
 
230
  # Reshape data for plotting
231
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
232
  df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
233
 
 
 
234
  # Set up the plot
235
  sns.set(rc={'figure.figsize': (13.7, 8.27)})
236
  sns.set_theme(style="whitegrid", color_codes=True)
@@ -246,12 +250,15 @@ def plot_affinity_results(file_path, method_names, metric, save_path="./plot_ima
246
  ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
247
 
248
  # Apply custom color settings to y-axis labels
249
- set_colors_and_marks_for_representation_groups(ax)
 
 
250
 
251
  # Ensure save path exists
252
  os.makedirs(save_path, exist_ok=True)
253
 
254
  # Save the plot
 
255
  filename = os.path.join(save_path, f"{metric}_affinity_results.png")
256
  ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
257
  plt.close() # Close the plot to free memory
 
183
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
184
  df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
185
 
186
+ df = df.fillna(0)
187
+
188
  # Set up the plot
189
  sns.set(rc={'figure.figsize': (13.7, 18.27)})
190
  sns.set_theme(style="whitegrid", color_codes=True)
 
216
 
217
  return filename
218
 
219
+ def plot_affinity_results(method_names, metric, affinity_path="/tmp/affinity_results.csv"):
220
+ if not os.path.exists(affinity_path):
221
+ benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
222
+ download_from_hub(benchmark_types)
223
+
224
+ df = pd.read_csv(affinity_path)
225
 
226
  # Filter for selected methods
227
  df = df[df['Method'].isin(method_names)]
228
 
229
  # Gather columns related to the specified metric and validate
230
  metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
 
 
 
231
 
232
  # Reshape data for plotting
233
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
234
  df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
235
 
236
+ df = df.fillna(0)
237
+
238
  # Set up the plot
239
  sns.set(rc={'figure.figsize': (13.7, 8.27)})
240
  sns.set_theme(style="whitegrid", color_codes=True)
 
250
  ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
251
 
252
  # Apply custom color settings to y-axis labels
253
+ for label in ax.get_yticklabels():
254
+ method = label.get_text()
255
+ label.set_color(get_method_color(method))
256
 
257
  # Ensure save path exists
258
  os.makedirs(save_path, exist_ok=True)
259
 
260
  # Save the plot
261
+ save_path = "/tmp"
262
  filename = os.path.join(save_path, f"{metric}_affinity_results.png")
263
  ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
264
  plt.close() # Close the plot to free memory