ramiroluo's picture
Add Application
91cf45c
raw
history blame
No virus
5.39 kB
import json
import pandas as pd
import streamlit as st
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def display_results(results, setting='avg', rank_metric='Entailment(↓)', is_auto=True):
    """Render a ranked horizontal bar chart of per-model label rates with Streamlit.

    The figure has two panels sharing the model (y) axis. The wide left panel
    stacks the Entailment / Neutral / Contradiction values. The narrow right
    panel shows the Abstain value on its own.

    Args:
        results: Mapping of model name -> {setting -> {'entailment', 'neutral',
            'contradiction', 'abstain'}}. Bar labels are the values rounded to
            one decimal place (presumably percentages -- confirm against the
            scores JSON).
        setting: Key selecting which evaluation setting to plot for every model.
        rank_metric: One of 'Entailment(↓)', 'Neutral(↑)', 'Contradiction(↑)'
            or 'Abstain(↑)'. Selects the sort key (ascending for Entailment,
            descending otherwise) and the stacking order of the main labels.
        is_auto: Chart height is 900px when True, 500px otherwise.

    Raises:
        ValueError: If ``rank_metric`` is not a supported value.
        KeyError: If a model lacks ``setting`` or one of the four label keys.
    """
    label_marker = {
        'Entailment': dict(color='rgba(102, 204, 0, 0.6)'),
        'Neutral': dict(color='rgba(255, 178, 102, 0.6)'),
        'Contradiction': dict(color='rgba(255, 51, 51, 0.6)'),
        'Abstain': dict(color='rgba(192, 192, 192, 0.6)'),
    }
    # Position of each label inside a row tuple; index 0 holds the model name.
    label_to_results_idx = {
        'Entailment': 1,
        'Neutral': 2,
        'Contradiction': 3,
        'Abstain': 4,
    }
    # rank_metric -> (label to sort by, sort descending?, stacking order of the
    # three main labels). Abstain is never stacked: it has its own panel.
    rank_configs = {
        'Entailment(↓)': ('Entailment', False, ['Entailment', 'Neutral', 'Contradiction']),
        'Neutral(↑)': ('Neutral', True, ['Neutral', 'Contradiction', 'Entailment']),
        'Contradiction(↑)': ('Contradiction', True, ['Contradiction', 'Neutral', 'Entailment']),
        'Abstain(↑)': ('Abstain', True, ['Contradiction', 'Neutral', 'Entailment']),
    }
    # Fail early with a clear message instead of iterating over None later.
    if rank_metric not in rank_configs:
        raise ValueError(
            f'Unsupported rank_metric {rank_metric!r}; expected one of {list(rank_configs)}'
        )
    sort_label, descending, label_order = rank_configs[rank_metric]

    rows = [
        (
            model_name,
            scores[setting]['entailment'],
            scores[setting]['neutral'],
            scores[setting]['contradiction'],
            scores[setting]['abstain'],
        )
        for model_name, scores in results.items()
    ]
    sort_idx = label_to_results_idx[sort_label]
    rows = sorted(rows, key=lambda row: row[sort_idx], reverse=descending)
    model_names = [row[0] for row in rows]

    def _bar(label):
        """Build one horizontal bar trace for ``label`` across all ranked models."""
        idx = label_to_results_idx[label]
        values = [row[idx] for row in rows]
        return go.Bar(
            y=model_names,
            x=values,
            name=label,
            orientation='h',
            marker=label_marker[label],
            text=[round(value, 1) for value in values],
        )

    # Wide left panel for the stacked main labels, narrow right panel for Abstain.
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, column_widths=[0.9, 0.1], horizontal_spacing=0)
    for label in label_order:
        fig.add_trace(_bar(label), row=1, col=1)
    fig.add_trace(_bar('Abstain'), row=1, col=2)

    fig.update_layout(
        barmode='stack',
        width=1000,
        height=900 if is_auto else 500,
        bargap=0.35,
        legend_font=dict(size=18),
    )
    fig.update_yaxes(tickfont=dict(size=19, color='black'))
    st.plotly_chart(fig)
if __name__ == '__main__':
    st.set_page_config(layout='wide')
    st.title('HalluChecker Leaderboard')
    st.write('[GitHub repo of HalluChecker](https://github.com/LuoXiaoHeics/HalluChecker)')

    # st.tabs returns a list of containers even for a single label; unpack it so
    # `with tab1:` receives a container rather than a list.
    (tab1,) = st.tabs(['Auto-checked Leaderboard'])
    with tab1:
        col1, col2 = st.columns([1, 7])
        with col1:
            extractor = st.radio('Claim-Triplet Extractor', ['GPT-4', 'Claude 2'])
            checker = st.radio('Checker', ['Ensemble of 3 Checkers', 'GPT-4', 'Claude 2', 'RoBERTa-NLI'])
            # Display name -> short key used to build the lookup key into the scores file.
            model_map = {
                'Ensemble of 3 Checkers': 'ensemble',
                'GPT-4': 'gpt4',
                'Claude 2': 'claude2',
                'RoBERTa-NLI': 'nli',
            }
            extractor = model_map[extractor]
            checker = model_map[checker]
            rank_metric = st.radio('Rank By:', ['Contradiction(↑)', 'Neutral(↑)', 'Entailment(↓)', 'Abstain(↑)'])
        with col2:
            # Close the file deterministically; JSON is UTF-8 by specification.
            with open('auto_leaderboard_scores.json', encoding='utf-8') as scores_file:
                results = json.load(scores_file)
            res_key = f'{extractor}###{checker}'
            if res_key not in results:
                st.write('Work in progress, please stay tuned 😊')
            else:
                results = results[res_key]
                setting_tabs = st.tabs(['Average over Settings', 'Zero Context', 'Noisy Context', 'Accurate Context'])
                # Each tab plots one setting key from the scores file, in tab order.
                for tab, setting in zip(setting_tabs, ['avg', 'nq', 'msmarco', 'dolly']):
                    with tab:
                        display_results(results, setting=setting, rank_metric=rank_metric)

    st.divider()
    # Raw string: '\*' is an invalid escape sequence in a normal literal; the
    # backslash is kept so Markdown renders a literal asterisk.
    st.write(r'\* The responses of Gemini Pro (Bard) are manually collected from [Google Bard](https://bard.google.com/) on December 7, 2023.')
    st.write('† The responses of Gemini Pro (API) are collected from its official API without tools.')
    st.write('♣ Our project is executed using the tool of RefChecker (https://github.com/amazon-science/RefChecker).')