Spaces:
Sleeping
Sleeping
Update src/vis_utils.py
Browse files- src/vis_utils.py +14 -7
src/vis_utils.py
CHANGED
@@ -183,6 +183,8 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
|
|
183 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
184 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
|
185 |
|
|
|
|
|
186 |
# Set up the plot
|
187 |
sns.set(rc={'figure.figsize': (13.7, 18.27)})
|
188 |
sns.set_theme(style="whitegrid", color_codes=True)
|
@@ -214,23 +216,25 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
|
|
214 |
|
215 |
return filename
|
216 |
|
217 |
-
def plot_affinity_results(
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
220 |
|
221 |
# Filter for selected methods
|
222 |
df = df[df['Method'].isin(method_names)]
|
223 |
|
224 |
# Gather columns related to the specified metric and validate
|
225 |
metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
|
226 |
-
if not metric_columns:
|
227 |
-
print(f"No columns found for metric '{metric}'.")
|
228 |
-
return None
|
229 |
|
230 |
# Reshape data for plotting
|
231 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
232 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
|
233 |
|
|
|
|
|
234 |
# Set up the plot
|
235 |
sns.set(rc={'figure.figsize': (13.7, 8.27)})
|
236 |
sns.set_theme(style="whitegrid", color_codes=True)
|
@@ -246,12 +250,15 @@ def plot_affinity_results(file_path, method_names, metric, save_path="./plot_ima
|
|
246 |
ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
|
247 |
|
248 |
# Apply custom color settings to y-axis labels
|
249 |
-
|
|
|
|
|
250 |
|
251 |
# Ensure save path exists
|
252 |
os.makedirs(save_path, exist_ok=True)
|
253 |
|
254 |
# Save the plot
|
|
|
255 |
filename = os.path.join(save_path, f"{metric}_affinity_results.png")
|
256 |
ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
|
257 |
plt.close() # Close the plot to free memory
|
|
|
183 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
184 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
|
185 |
|
186 |
+
df = df.fillna(0)
|
187 |
+
|
188 |
# Set up the plot
|
189 |
sns.set(rc={'figure.figsize': (13.7, 18.27)})
|
190 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
|
216 |
|
217 |
return filename
|
218 |
|
219 |
+
def plot_affinity_results(method_names, metric, affinity_path="/tmp/affinity_results.csv"):
|
220 |
+
if not os.path.exists(affinity_path):
|
221 |
+
benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
|
222 |
+
download_from_hub(benchmark_types)
|
223 |
+
|
224 |
+
df = pd.read_csv(affinity_path)
|
225 |
|
226 |
# Filter for selected methods
|
227 |
df = df[df['Method'].isin(method_names)]
|
228 |
|
229 |
# Gather columns related to the specified metric and validate
|
230 |
metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
|
|
|
|
|
|
|
231 |
|
232 |
# Reshape data for plotting
|
233 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
234 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
|
235 |
|
236 |
+
df = df.fillna(0)
|
237 |
+
|
238 |
# Set up the plot
|
239 |
sns.set(rc={'figure.figsize': (13.7, 8.27)})
|
240 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
|
250 |
ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
|
251 |
|
252 |
# Apply custom color settings to y-axis labels
|
253 |
+
for label in ax.get_yticklabels():
|
254 |
+
method = label.get_text()
|
255 |
+
label.set_color(get_method_color(method))
|
256 |
|
257 |
# Ensure save path exists
|
258 |
os.makedirs(save_path, exist_ok=True)
|
259 |
|
260 |
# Save the plot
|
261 |
+
save_path = "/tmp"
|
262 |
filename = os.path.join(save_path, f"{metric}_affinity_results.png")
|
263 |
ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
|
264 |
plt.close() # Close the plot to free memory
|