# Blair Yang
# debug
# db9289d
# raw
# history blame contribute delete
# No virus
# 3.06 kB
import gradio as gr
import plotly.express as px
import plotly.graph_objs as go
from collections import defaultdict
import json, math, gdown
import numpy as np
import pandas as pd
from Config import *
# Render floats with two decimals whenever pandas prints a DataFrame/Series.
pd.set_option('display.float_format', '{:.2f}'.format)
# 100 evenly spaced points from 0 to 100 inclusive (not referenced by the
# plotting code visible in this file).
battles = np.linspace(start=0, stop=100, num=100)
# Meta-topics selectable by index in generate_plot; each one maps to a
# data/<meta_topic>/ directory.
meta_topics = ["mmlu"]
def generate_plot(meta_index, topic_index):
    """Build a grouped bar chart of per-model accuracy for one sub-topic.

    Reads ``data/<meta_topic>/response_rec.csv``, keeps the rows whose
    ``sub_topic`` equals the selected topic, and plots for every
    ``model_name`` the mean oracle, human and LLM accuracy, with the sample
    standard deviation across those rows as error bars.

    Args:
        meta_index: Index into the module-level ``meta_topics`` list.
        topic_index: Index into ``TOPICS[meta_topic]`` (``TOPICS`` is
            provided by ``Config``).

    Returns:
        A ``plotly.graph_objs.Figure`` with one bar trace per accuracy type,
        grouped by model.

    Raises:
        IndexError: if ``meta_index`` is out of range for ``meta_topics``.
        Errors from ``pd.read_csv`` (e.g. a missing file) propagate.
    """
    meta_topic = meta_topics[meta_index]
    topic = TOPICS[meta_topic][topic_index]

    data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
    # .copy() so adding accuracy columns does not write into a slice of `data`.
    topic_data = data.loc[data['sub_topic'] == topic].copy()

    # Per-row accuracies. A zero response count becomes NaN rather than
    # producing inf, and mean()/std() skip NaN rows.
    topic_data['human_acc'] = (
        topic_data['no_correct_human']
        / topic_data['no_responses_human'].replace(0, np.nan)
    )
    topic_data['llm_acc'] = (
        topic_data['no_correct_llm']
        / topic_data['no_responses_llm'].replace(0, np.nan)
    )

    # Legend label and fixed color for each plotted accuracy column, so every
    # panel uses the same colors.
    styles = {
        'oracle_acc': ('Oracle', '#FFA07A'),  # Light Salmon
        'human_acc': ('Human', '#20B2AA'),  # Light Sea Green
        'llm_acc': ('LLM', '#778899'),  # Light Slate Gray
    }
    acc_types = list(styles)

    # A single groupby gives the mean and std (ddof=1) per model, sharing one
    # index. A model with only one row gets a NaN std, so no error bar is drawn.
    stats = topic_data.groupby('model_name')[acc_types].agg(['mean', 'std'])

    plot_data = [
        go.Bar(
            x=stats.index,
            y=stats[(acc_type, 'mean')],
            error_y=dict(
                type='data',
                array=stats[(acc_type, 'std')],
                visible=True,
            ),
            name=label,
            marker=dict(color=color),
        )
        for acc_type, (label, color) in styles.items()
    ]

    layout = go.Layout(
        title=f"Accuracy for {meta_topic} ({topic})",
        xaxis=dict(title='Model Name'),
        yaxis=dict(title='Accuracy'),
        showlegend=True,
        legend=dict(title='Accuracy Type'),
        barmode='group',
    )
    return go.Figure(data=plot_data, layout=layout)
# Gradio interface: a 3x2 grid of accuracy plots.
# Every panel currently shows the same (meta_index=0, topic_index=0) plot, so
# the figure is built once and shared instead of re-reading the CSV per panel.
# TODO(review): earlier commented-out calls suggested per-panel indices
# (0,0), (0,1) / (1,0), (1,1) / (2,0), (2,1). meta_index values above 0 need
# more entries in `meta_topics` before they can be used.
default_fig = generate_plot(0, 0)
with gr.Blocks() as interface:
    with gr.Row():  # Row 1
        plot1 = gr.Plot(default_fig)
        plot2 = gr.Plot(default_fig)
    with gr.Row():  # Row 2
        plot3 = gr.Plot(default_fig)
        plot4 = gr.Plot(default_fig)
    with gr.Row():  # Row 3
        plot5 = gr.Plot(default_fig)
        plot6 = gr.Plot(default_fig)
interface.launch()