File size: 3,062 Bytes
e159d95
5264831
 
 
 
 
 
 
 
e159d95
 
5264831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db9289d
5264831
 
 
 
 
db9289d
 
 
 
5264831
 
 
 
db9289d
5264831
db9289d
 
5264831
db9289d
5264831
 
 
 
 
 
 
 
 
db9289d
5264831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import plotly.express as px
import plotly.graph_objs as go
from collections import defaultdict
import json, math, gdown
import numpy as np
import pandas as pd
from Config import *
# Show floats with two decimals whenever pandas prints a DataFrame/Series.
pd.set_option("display.float_format", "{:.2f}".format)

# 100 evenly spaced points on [0, 100].
# NOTE(review): `battles` is not referenced by any code visible in this file.
battles = np.linspace(0, 100, 100)

# Datasets selectable by index in generate_plot (which reads
# data/<name>/response_rec.csv by default).
meta_topics = ['mmlu']

def generate_plot(meta_index, topic_index, data_dir="data"):
    """
    Build a grouped bar chart of per-model accuracy for one sub-topic.

    Reads ``<data_dir>/<meta_topic>/response_rec.csv``, keeps the rows whose
    ``sub_topic`` equals the selected topic, and plots, per ``model_name``,
    the mean oracle / human / LLM accuracy with the per-model standard
    deviation as an error bar.

    Args:
        meta_index: Index into ``meta_topics`` selecting the dataset.
        topic_index: Index into ``TOPICS[meta_topic]`` selecting the
            sub-topic. ``TOPICS`` comes from ``Config``; presumably a dict
            mapping each meta-topic to a list of sub-topic names — confirm.
        data_dir: Directory holding one sub-directory per meta-topic.
            Defaults to ``"data"``, the previously hard-coded location.

    Returns:
        A ``plotly.graph_objs.Figure`` with one bar group per model.

    Raises:
        IndexError: If ``meta_index`` is out of range for ``meta_topics``.
        FileNotFoundError: If the meta-topic's CSV does not exist.
    """
    meta_topic = meta_topics[meta_index]
    topic = TOPICS[meta_topic][topic_index]

    data = pd.read_csv(f"{data_dir}/{meta_topic}/response_rec.csv", sep=",")
    topic_data = data.loc[data['sub_topic'] == topic].copy()

    # Per-row accuracies. A zero response count is turned into NaN so the
    # division yields NaN (not inf); mean()/std() below skip NaN values.
    topic_data['human_acc'] = (
        topic_data['no_correct_human']
        / topic_data['no_responses_human'].replace(0, np.nan)
    )
    topic_data['llm_acc'] = (
        topic_data['no_correct_llm']
        / topic_data['no_responses_llm'].replace(0, np.nan)
    )

    # (column, legend label, bar colour) for each bar within a model's group.
    series = [
        ('oracle_acc', 'Oracle', '#FFA07A'),  # Light Salmon
        ('human_acc', 'Human', '#20B2AA'),    # Light Sea Green
        ('llm_acc', 'LLM', '#778899'),        # Light Slate Gray
    ]
    acc_cols = [col for col, _, _ in series]

    # One groupby feeds both aggregations, so the mean and std rows come out
    # in the same model order and line up bar-for-bar.
    grouped = topic_data.groupby('model_name')[acc_cols]
    mean_data = grouped.mean().reset_index()
    std_data = grouped.std().reset_index()  # NaN for a model with a single row

    plot_data = [
        go.Bar(
            x=mean_data['model_name'],
            y=mean_data[col],
            error_y=dict(type='data', array=std_data[col], visible=True),
            name=label,
            marker=dict(color=color),
        )
        for col, label, color in series
    ]

    layout = go.Layout(
        title=f"Accuracy for {meta_topic} ({topic})",
        xaxis=dict(title='Model Name'),
        yaxis=dict(title='Accuracy'),
        showlegend=True,
        legend=dict(title='Accuracy Type'),
        barmode='group',
    )

    return go.Figure(data=plot_data, layout=layout)


# Gradio dashboard: three rows of two accuracy plots each.
# NOTE(review): every panel currently renders generate_plot(0, 0). meta_topics
# holds a single entry, so a meta index above 0 would raise IndexError until
# more meta-topics are added.
with gr.Blocks() as interface:
    panels = []
    for _row in range(3):
        with gr.Row():
            for _col in range(2):
                panels.append(gr.Plot(generate_plot(0, 0)))
    plot1, plot2, plot3, plot4, plot5, plot6 = panels

interface.launch()