Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import pandas as pd
|
3 |
+
import gradio as gr
|
4 |
+
import litellm
|
5 |
+
import plotly.express as px
|
6 |
+
from collections import defaultdict
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
def preprocess_dataset(test_data):
|
10 |
+
"""
|
11 |
+
Preprocess the dataset to convert the 'choices' field from a string to a list of strings.
|
12 |
+
"""
|
13 |
+
preprocessed_data = []
|
14 |
+
for example in test_data:
|
15 |
+
if isinstance(example['choices'], str):
|
16 |
+
choices_str = example['choices']
|
17 |
+
if choices_str.startswith("'") and choices_str.endswith("'"):
|
18 |
+
choices_str = choices_str[1:-1]
|
19 |
+
elif choices_str.startswith('"') and choices_str.endswith('"'):
|
20 |
+
choices_str = choices_str[1:-1]
|
21 |
+
choices_str = choices_str.replace("\\'", "'")
|
22 |
+
try:
|
23 |
+
example['choices'] = ast.literal_eval(choices_str)
|
24 |
+
except (ValueError, SyntaxError):
|
25 |
+
print(f"Error parsing choices: {choices_str}")
|
26 |
+
continue
|
27 |
+
preprocessed_data.append(example)
|
28 |
+
return preprocessed_data
|
29 |
+
|
30 |
+
def evaluate_afrimmlu(test_data, model_name="deepseek-chat"):
|
31 |
+
"""
|
32 |
+
Evaluate the model on the AfriMMLU dataset.
|
33 |
+
"""
|
34 |
+
results = []
|
35 |
+
correct = 0
|
36 |
+
total = 0
|
37 |
+
subject_results = defaultdict(lambda: {"correct": 0, "total": 0})
|
38 |
+
|
39 |
+
for example in test_data:
|
40 |
+
question = example['question']
|
41 |
+
choices = example['choices']
|
42 |
+
answer = example['answer']
|
43 |
+
subject = example['subject']
|
44 |
+
|
45 |
+
prompt = (
|
46 |
+
f"Answer the following multiple-choice question. "
|
47 |
+
f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
|
48 |
+
f"Question: {question}\n"
|
49 |
+
f"Options:\n"
|
50 |
+
f"A. {choices[0]}\n"
|
51 |
+
f"B. {choices[1]}\n"
|
52 |
+
f"C. {choices[2]}\n"
|
53 |
+
f"D. {choices[3]}\n"
|
54 |
+
f"Answer:"
|
55 |
+
)
|
56 |
+
|
57 |
+
try:
|
58 |
+
response = litellm.completion(
|
59 |
+
model=model_name,
|
60 |
+
messages=[{"role": "user", "content": prompt}]
|
61 |
+
)
|
62 |
+
model_output = response.choices[0].message.content.strip().upper()
|
63 |
+
|
64 |
+
model_answer = None
|
65 |
+
for char in model_output:
|
66 |
+
if char in ['A', 'B', 'C', 'D']:
|
67 |
+
model_answer = char
|
68 |
+
break
|
69 |
+
|
70 |
+
is_correct = model_answer == answer.upper()
|
71 |
+
if is_correct:
|
72 |
+
correct += 1
|
73 |
+
subject_results[subject]["correct"] += 1
|
74 |
+
total += 1
|
75 |
+
subject_results[subject]["total"] += 1
|
76 |
+
|
77 |
+
# Store detailed results
|
78 |
+
results.append({
|
79 |
+
'timestamp': datetime.now().isoformat(),
|
80 |
+
'subject': subject,
|
81 |
+
'question': question,
|
82 |
+
'model_answer': model_answer,
|
83 |
+
'correct_answer': answer.upper(),
|
84 |
+
'is_correct': is_correct,
|
85 |
+
'total_tokens': response.usage.total_tokens
|
86 |
+
})
|
87 |
+
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error processing question: {str(e)}")
|
90 |
+
continue
|
91 |
+
|
92 |
+
# Calculate accuracies
|
93 |
+
accuracy = (correct / total * 100) if total > 0 else 0
|
94 |
+
subject_accuracy = {
|
95 |
+
subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
|
96 |
+
for subject, stats in subject_results.items()
|
97 |
+
}
|
98 |
+
|
99 |
+
# Export results to CSV
|
100 |
+
df = pd.DataFrame(results)
|
101 |
+
df.to_csv('detailed_results.csv', index=False)
|
102 |
+
|
103 |
+
# Export summary to CSV
|
104 |
+
summary_data = [{'subject': subject, 'accuracy': acc}
|
105 |
+
for subject, acc in subject_accuracy.items()]
|
106 |
+
summary_data.append({'subject': 'Overall', 'accuracy': accuracy})
|
107 |
+
pd.DataFrame(summary_data).to_csv('summary_results.csv', index=False)
|
108 |
+
|
109 |
+
return {
|
110 |
+
"accuracy": accuracy,
|
111 |
+
"subject_accuracy": subject_accuracy,
|
112 |
+
"detailed_results": results
|
113 |
+
}
|
114 |
+
|
115 |
+
def create_visualization(results_dict):
|
116 |
+
"""
|
117 |
+
Create visualization from evaluation results.
|
118 |
+
"""
|
119 |
+
summary_data = [
|
120 |
+
{'Subject': subject, 'Accuracy (%)': accuracy}
|
121 |
+
for subject, accuracy in results_dict['subject_accuracy'].items()
|
122 |
+
]
|
123 |
+
summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results_dict['accuracy']})
|
124 |
+
summary_df = pd.DataFrame(summary_data)
|
125 |
+
|
126 |
+
fig = px.bar(
|
127 |
+
summary_df,
|
128 |
+
x='Subject',
|
129 |
+
y='Accuracy (%)',
|
130 |
+
title='AfriMMLU Evaluation Results',
|
131 |
+
labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'}
|
132 |
+
)
|
133 |
+
fig.update_layout(
|
134 |
+
xaxis_tickangle=-45,
|
135 |
+
showlegend=False,
|
136 |
+
height=600
|
137 |
+
)
|
138 |
+
|
139 |
+
return summary_df, fig
|
140 |
+
|
141 |
+
def evaluate_and_display(test_file, model_name):
|
142 |
+
"""
|
143 |
+
Process uploaded file and run evaluation.
|
144 |
+
"""
|
145 |
+
test_data = pd.read_json(test_file.name)
|
146 |
+
preprocessed_data = preprocess_dataset(test_data.to_dict('records'))
|
147 |
+
|
148 |
+
results = evaluate_afrimmlu(preprocessed_data, model_name)
|
149 |
+
|
150 |
+
summary_df, plot = create_visualization(results)
|
151 |
+
detailed_df = pd.read_csv('detailed_results.csv')
|
152 |
+
|
153 |
+
return summary_df, plot, detailed_df
|
154 |
+
|
155 |
+
def create_gradio_interface():
|
156 |
+
"""
|
157 |
+
Create and configure the Gradio interface.
|
158 |
+
"""
|
159 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
160 |
+
gr.Markdown("""
|
161 |
+
# AfriMMLU Evaluation Dashboard
|
162 |
+
Upload your test data and select a model to evaluate performance on the AfriMMLU benchmark.
|
163 |
+
""")
|
164 |
+
|
165 |
+
with gr.Row():
|
166 |
+
with gr.Column(scale=1):
|
167 |
+
file_input = gr.File(
|
168 |
+
label="Upload Test Data (JSON)",
|
169 |
+
file_types=[".json"]
|
170 |
+
)
|
171 |
+
model_input = gr.Dropdown(
|
172 |
+
choices=["deepseek-chat", "gpt-3.5-turbo", "gpt-4"],
|
173 |
+
label="Select Model",
|
174 |
+
value="deepseek-chat"
|
175 |
+
)
|
176 |
+
evaluate_btn = gr.Button("Evaluate", variant="primary")
|
177 |
+
|
178 |
+
with gr.Row():
|
179 |
+
with gr.Column():
|
180 |
+
summary_table = gr.Dataframe(
|
181 |
+
headers=["Subject", "Accuracy (%)"],
|
182 |
+
label="Summary Results"
|
183 |
+
)
|
184 |
+
|
185 |
+
with gr.Row():
|
186 |
+
with gr.Column():
|
187 |
+
summary_plot = gr.Plot(label="Performance by Subject")
|
188 |
+
|
189 |
+
with gr.Row():
|
190 |
+
with gr.Column():
|
191 |
+
detailed_results = gr.Dataframe(
|
192 |
+
label="Detailed Results",
|
193 |
+
wrap=True
|
194 |
+
)
|
195 |
+
|
196 |
+
evaluate_btn.click(
|
197 |
+
fn=evaluate_and_display,
|
198 |
+
inputs=[file_input, model_input],
|
199 |
+
outputs=[summary_table, summary_plot, detailed_results]
|
200 |
+
)
|
201 |
+
|
202 |
+
return demo
|
203 |
+
|
204 |
+
if __name__ == "__main__":
|
205 |
+
demo = create_gradio_interface()
|
206 |
+
demo.launch(share=True)
|