|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
from .utils import undo_hyperlink |
|
|
|
def plot_avg_correlation(df1, df2, xlim=(0.475, 0.8), ylim=(0.475, 0.8)):
    """
    Scatter-plot the "average" score of every model present in both dataframes.

    Each point is one model. Its x coordinate is the model's "average" in
    ``df1`` and its y coordinate is its "average" in ``df2``. Each point is
    annotated, just to its left, with the name returned by ``undo_hyperlink``
    for that model. A name of "No text found" is displayed as "Random".

    Parameters:
    - df1: pandas DataFrame containing columns "model" and "average"
      (plotted on the x axis).
    - df2: pandas DataFrame containing columns "model" and "average"
      (plotted on the y axis).
    - xlim: (min, max) limits for the x axis, or None to autoscale.
      Defaults to (0.475, 0.8).
    - ylim: (min, max) limits for the y axis, or None to autoscale.
      Defaults to (0.475, 0.8).

    Returns:
    - The ``matplotlib.pyplot`` module, with the new figure as the current
      figure.

    Note: this updates the global ``plt.rcParams`` font sizes, so later
    figures in the same process are affected as well.
    NOTE(review): assumes each model appears at most once per dataframe;
    duplicates produce one point per (df1 row, df2 row) pair.
    """
    # Pair each model's df1/df2 averages with one inner join instead of
    # re-filtering both frames per model. Rows follow df1's order, so each
    # model's colour is stable across runs (iterating a set of strings is
    # not, due to hash randomisation).
    paired = pd.merge(
        df1[['model', 'average']],
        df2[['model', 'average']],
        on='model',
        how='inner',
        suffixes=('_1', '_2'),
    )

    plt.figure(figsize=(13, 6), constrained_layout=True)

    # Fixing the limits before plotting also disables autoscaling, so the
    # scatter calls below do not widen the view.
    if xlim is not None:
        plt.xlim(*xlim)
    if ylim is not None:
        plt.ylim(*ylim)

    plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14, 'axes.titlesize': 14})

    for model, x, y in zip(paired['model'], paired['average_1'], paired['average_2']):
        # Scalars (not 1-element arrays) so plt.text gets a proper position.
        plt.scatter(x, y, label=model)

        m_name = undo_hyperlink(model)
        if m_name == "No text found":
            m_name = "Random"

        # Right-aligned label, nudged left of the marker so they don't overlap.
        plt.text(x - .005, y, m_name, horizontalalignment='right', verticalalignment='center')

    plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
    plt.ylabel('Pref. Test Sets Avg.', fontsize=16)

    return plt