import gradio as gr
import plotly.express as px
import plotly.graph_objs as go
from collections import defaultdict
from functools import lru_cache
import json, math, gdown
import numpy as np
import pandas as pd
from Config import *  # presumably provides TOPICS (used by generate_plot) — verify

pd.options.display.float_format = '{:.2f}'.format

battles = np.linspace(0, 100, 100)
meta_topics = ['mmlu']

# (column, legend label, bar colour) for each accuracy type, in plotting order.
# Labels are explicit so "llm_acc" is shown as "LLM" rather than "Llm".
ACC_STYLES = [
    ('oracle_acc', 'Oracle', '#FFA07A'),  # Light Salmon
    ('human_acc', 'Human', '#20B2AA'),    # Light Sea Green
    ('llm_acc', 'LLM', '#778899'),        # Light Slate Gray
]


@lru_cache(maxsize=None)  # key space is bounded by the entries of meta_topics
def _load_responses(meta_topic):
    """Read data/<meta_topic>/response_rec.csv once and reuse it.

    The returned DataFrame is shared between callers; treat it as read-only.
    """
    return pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")


def generate_plot(meta_index, topic_index):
    """Build a grouped bar chart of per-model accuracies for one sub-topic.

    Args:
        meta_index: Index into ``meta_topics``; selects the ``data/<meta_topic>``
            folder whose ``response_rec.csv`` is read.
        topic_index: Index into ``TOPICS[meta_topic]``; selects the value of the
            ``sub_topic`` column whose rows are plotted.

    Returns:
        A ``plotly.graph_objs.Figure`` with one bar group per ``model_name`` and
        one bar per accuracy type (oracle, human, LLM). Bar heights are the
        per-model mean accuracy; error bars are the per-model standard deviation.

    Raises:
        IndexError: If either index is out of range.
    """
    meta_topic = meta_topics[meta_index]
    topic = TOPICS[meta_topic][topic_index]

    data = _load_responses(meta_topic)
    topic_data = data.loc[data['sub_topic'] == topic].copy()

    # Per-row accuracies. A zero response count becomes NaN (not inf), so that
    # row is skipped by mean()/std() below.
    topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
    topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)

    # Aggregate only numeric columns; one groupby keeps mean/std rows aligned.
    numeric_cols = ['no_responses_human', 'no_correct_human', 'no_responses_llm',
                    'no_correct_llm', 'oracle_acc', 'human_acc', 'llm_acc']
    grouped = topic_data.groupby('model_name')[numeric_cols]
    mean_data = grouped.mean().reset_index()
    std_deviation = grouped.std().reset_index()

    # One trace per accuracy type: mean bar with a std-deviation error bar.
    plot_data = [
        go.Bar(
            x=mean_data['model_name'],
            y=mean_data[acc_type],
            error_y=dict(
                type='data',
                array=std_deviation[acc_type],
                visible=True,
            ),
            name=label,
            marker=dict(color=color),
        )
        for acc_type, label, color in ACC_STYLES
    ]

    layout = go.Layout(
        title=f"Accuracy for {meta_topic} ({topic})",
        xaxis=dict(title='Model Name'),
        yaxis=dict(title='Accuracy'),
        showlegend=True,
        legend=dict(title='Accuracy Type'),
        barmode='group',
    )

    fig = go.Figure(data=plot_data, layout=layout)
    return fig


# Gradio interface with a 3x2 grid layout.
# NOTE(review): all six cells render generate_plot(0, 0). Giving each cell its
# own (meta_index, topic_index) needs more entries in meta_topics; any
# meta_index other than 0 currently raises IndexError.
with gr.Blocks() as interface:
    with gr.Row():  # Row 1
        plot1 = gr.Plot(generate_plot(0, 0))
        plot2 = gr.Plot(generate_plot(0, 0))
    with gr.Row():  # Row 2
        plot3 = gr.Plot(generate_plot(0, 0))
        plot4 = gr.Plot(generate_plot(0, 0))
    with gr.Row():  # Row 3
        plot5 = gr.Plot(generate_plot(0, 0))
        plot6 = gr.Plot(generate_plot(0, 0))

interface.launch()