Spaces:
Running
Running
import gradio as gr | |
from functools import lru_cache | |
import random | |
import requests | |
import logging | |
import arena_config | |
import plotly.graph_objects as go | |
from typing import Dict | |
from leaderboard import ( | |
get_current_leaderboard, | |
update_leaderboard, | |
start_backup_thread, | |
get_leaderboard, | |
get_elo_leaderboard, | |
ensure_elo_ratings_initialized | |
) | |
import sys | |
from fun_stats import get_fun_stats | |
import threading | |
import time | |
from collections import Counter | |
from model_suggestions import add_suggestion, get_suggestions_html | |
# Module logger; the root configuration below reports errors only.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)

# Launch the backup thread (imported from the leaderboard module) at import time.
start_backup_thread()
# List the models that may take part in a battle (predefined list)
def get_available_models():
    """Return the identifier of every approved model.

    The identifier is the first element of each entry in
    ``arena_config.APPROVED_MODELS``.
    """
    model_ids = []
    for entry in arena_config.APPROVED_MODELS:
        model_ids.append(entry[0])
    return model_ids
# Function to call the chat completions endpoint of the Ollama API
def call_ollama_api(model, prompt):
    """Send ``prompt`` to ``model`` and return the text of its reply.

    A fixed system message instructs the model not to reveal its identity.
    The request goes to ``{arena_config.API_URL}/v1/chat/completions`` with a
    100 second timeout.

    Args:
        model: Identifier of the model to query.
        prompt: The user's message.

    Returns:
        The reply text, or the string
        ``"Error: Unable to get response from the model."`` when the request
        fails or the response does not have the expected shape. Callers detect
        failures by looking for that text, so no exception is raised.
    """
    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant. At no point should you reveal your name, identity or team affiliation to the user, especially if asked directly!"
            },
            {
                "role": "user",
                "content": prompt
            }
        ]
    }
    try:
        response = requests.post(
            f"{arena_config.API_URL}/v1/chat/completions",
            headers=arena_config.HEADERS,
            json=payload,
            timeout=100
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        if not isinstance(content, str):
            # Callers run substring checks on the reply, so a null or
            # structured content field is treated as a malformed response.
            raise TypeError(f"unexpected content type {type(content).__name__}")
        return content
    except requests.exceptions.RequestException as e:
        logger.error("Error calling Ollama API for model %s: %s", model, e)
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # The body was not JSON or lacked choices[0].message.content.
        logger.error("Unexpected response from Ollama API for model %s: %r", model, e)
    return "Error: Unable to get response from the model."
# Count how many battles each model on the leaderboard has fought
def get_battle_counts():
    """Return a ``Counter`` mapping each leaderboard model to wins + losses."""
    return Counter({
        name: record['wins'] + record['losses']
        for name, record in get_current_leaderboard().items()
    })
def generate_responses(prompt):
    """Pick two models and collect their answers to ``prompt``.

    The first model is the approved model with the fewest recorded battles.
    The second is drawn at random from the remaining models with weight
    ``1 / (battles + 1)``, so models with fewer battles are more likely.
    A model with no recorded battles counts as 0 battles for both the
    ordering and the weighting.

    Args:
        prompt: The user's prompt, sent unchanged to both models.

    Returns:
        A tuple ``(response_a, response_b, model_a, model_b)``. When fewer
        than two models are available, both responses are the text
        ``"Error: Not enough models available"`` and both models are ``None``.
    """
    available_models = get_available_models()
    if len(available_models) < 2:
        not_enough = "Error: Not enough models available"
        return not_enough, not_enough, None, None

    battle_counts = get_battle_counts()

    def battles_fought(model):
        # One lookup rule everywhere: unknown models have fought 0 battles.
        return battle_counts.get(model, 0)

    # Sort models by battle count (ascending); the least-battled one goes first.
    model_a, *candidates = sorted(available_models, key=battles_fought)

    # For the second model, use weighted random selection.
    weights = [1 / (battles_fought(m) + 1) for m in candidates]
    model_b = random.choices(candidates, weights=weights, k=1)[0]

    model_a_response = call_ollama_api(model_a, prompt)
    model_b_response = call_ollama_api(model_b, prompt)
    return model_a_response, model_b_response, model_a, model_b
def battle_arena(prompt):
    """Run one battle round and return the updates for the battle tab.

    Two models answer ``prompt``. Their answers are shown under two random
    nicknames, and which model lands on the left is decided by a coin flip.

    Args:
        prompt: The user's prompt.

    Returns:
        A 13-tuple in the order the submit handler lists its outputs:
        left chat, right chat, left model, right model, left chat update,
        right chat update, left vote button, right vote button, tie button,
        previous prompt, tie count, model-names row, status box.
        If no responses could be produced, the chats are cleared, voting is
        disabled and the status box shows the error text.
    """
    response_a, response_b, model_a, model_b = generate_responses(prompt)

    # Check for errors: either no models were selected or an API call failed.
    api_error = "Error: Unable to get response from the model"
    failure = None
    if model_a is None or model_b is None:
        # generate_responses puts its own error text in the response slots.
        failure = response_a
    elif api_error in response_a or api_error in response_b:
        failure = api_error

    if failure is not None:
        return (
            [], [], None, None,
            gr.update(value=[]),
            gr.update(value=[]),
            gr.update(interactive=False, value="Voting Disabled - API Error"),
            gr.update(interactive=False, value="Voting Disabled - API Error"),
            gr.update(interactive=False, visible=False),
            prompt,
            0,
            gr.update(visible=False),
            gr.update(value=failure, visible=True)
        )

    # Use two different nicknames when possible so the vote buttons are
    # distinguishable.
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = random.choice(nicknames)
        nickname_b = random.choice(nicknames)

    # Format responses for gr.Chatbot, including the user's prompt
    left_chat = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_a}
    ]
    right_chat = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_b}
    ]
    left_model, right_model = model_a, model_b

    # model_a is always the least-battled model, so flip a coin for the side
    # it appears on.
    if not random.choice([True, False]):
        left_chat, right_chat = right_chat, left_chat
        left_model, right_model = right_model, left_model

    return (
        left_chat, right_chat, left_model, right_model,
        gr.update(label=nickname_a, value=left_chat),
        gr.update(label=nickname_b, value=right_chat),
        gr.update(interactive=True, value=f"Vote for {nickname_a}"),
        gr.update(interactive=True, value=f"Vote for {nickname_b}"),
        gr.update(interactive=True, visible=True),
        prompt,
        0,
        gr.update(visible=False),
        gr.update(value="Ready for your vote! π³οΈ", visible=True)
    )
def record_vote(prompt, left_response, right_response, left_model, right_model, choice):
    """Record the user's vote and return the updates for the vote handlers.

    Args:
        prompt: The prompt that was battled over (not used here).
        left_response: Chat content shown on the left.
        right_response: Chat content shown on the right.
        left_model: Identifier of the model on the left.
        right_model: Identifier of the model on the right.
        choice: ``"Left is better"`` makes the left model the winner; any
            other value makes the right model the winner.

    Returns:
        An 8-tuple in the order the vote handlers list their outputs:
        status box, leaderboard, ELO leaderboard, left vote button, right vote
        button, tie button, model-names row, leaderboard chart. Both the
        success path and the "nothing to vote on" path return all eight.
    """
    # Check if outputs are generated
    if not left_response or not right_response or not left_model or not right_model:
        return (
            "Please generate responses before voting.",
            gr.update(),  # Leaderboard unchanged
            gr.update(),  # ELO leaderboard unchanged
            gr.update(interactive=False),  # Disable left vote button
            gr.update(interactive=False),  # Disable right vote button
            gr.update(interactive=False),  # Disable tie button
            gr.update(visible=False),  # Keep model names hidden
            gr.update()  # Chart unchanged
        )

    if choice == "Left is better":
        winner, loser = left_model, right_model
    else:
        winner, loser = right_model, left_model

    # Update the leaderboard
    update_leaderboard(winner, loser)

    result_message = f"""
π Vote recorded! You're awesome! π
π΅ In the left corner: {get_human_readable_name(left_model)}
π΄ In the right corner: {get_human_readable_name(right_model)}
π And the champion you picked is... {get_human_readable_name(winner)}! π₯
"""
    return (
        gr.update(value=result_message, visible=True),  # Show result message
        get_leaderboard(),  # Update leaderboard
        get_elo_leaderboard(),  # Update ELO leaderboard
        gr.update(interactive=False),  # Disable left vote button
        gr.update(interactive=False),  # Disable right vote button
        gr.update(interactive=False),  # Disable tie button
        gr.update(visible=True),  # Show model names
        get_leaderboard_chart()  # Update leaderboard chart
    )
def get_leaderboard_chart():
    """Build the "Model Performance" Plotly figure from the current leaderboard.

    Every model gets a stacked wins/losses bar plus a score drawn as a line on
    a secondary y-axis. The score is the win rate multiplied by
    ``1 - 1 / (battles + 1)``, which grows with the number of battles; a model
    with no battles scores 0. Models are ordered by score, then by number of
    battles, both descending.

    Returns:
        The ``plotly.graph_objects.Figure`` to display.
    """
    battle_results = get_current_leaderboard()

    # Scores are kept in local rows instead of being written back into the
    # leaderboard entries, so the data handed out by get_current_leaderboard()
    # is not modified by drawing a chart.
    rows = []
    for model, results in battle_results.items():
        model_wins = results["wins"]
        model_losses = results["losses"]
        total_battles = model_wins + model_losses
        if total_battles > 0:
            win_rate = model_wins / total_battles
            score = win_rate * (1 - 1 / (total_battles + 1))
        else:
            score = 0
        rows.append((model, model_wins, model_losses, score))

    # Best score first; ties are broken by the number of battles fought.
    rows.sort(key=lambda row: (row[3], row[1] + row[2]), reverse=True)

    models = [get_human_readable_name(row[0]) for row in rows]
    wins = [row[1] for row in rows]
    losses = [row[2] for row in rows]
    scores = [row[3] for row in rows]

    fig = go.Figure()
    # Stacked Bar chart for Wins and Losses
    fig.add_trace(go.Bar(
        x=models,
        y=wins,
        name='Wins',
        marker_color='#22577a'
    ))
    fig.add_trace(go.Bar(
        x=models,
        y=losses,
        name='Losses',
        marker_color='#38a3a5'
    ))
    # Line chart for Scores
    fig.add_trace(go.Scatter(
        x=models,
        y=scores,
        name='Score',
        yaxis='y2',
        line=dict(color='#ff7f0e', width=2)
    ))
    # Update layout for full-width, increased height, and secondary y-axis
    fig.update_layout(
        title='Model Performance',
        xaxis_title='Models',
        yaxis_title='Number of Battles',
        yaxis2=dict(
            title='Score',
            overlaying='y',
            side='right'
        ),
        barmode='stack',
        height=800,
        width=1450,
        autosize=True,
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='right',
            x=1
        )
    )
    return fig
def new_battle():
    """Reset the battle tab so a fresh battle can be started.

    Returns:
        A 13-tuple in the order the "New Battle" handler lists its outputs:
        prompt input, left chat, right chat, left model, right model, left
        vote button, right vote button, tie button, status box, leaderboard,
        model-names row, leaderboard chart, tie count. The leaderboard and
        chart entries are no-op updates.
    """
    # Use two different nicknames when possible so the vote buttons are
    # distinguishable.
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = random.choice(nicknames)
        nickname_b = random.choice(nicknames)

    return (
        "",  # Reset prompt_input
        gr.update(value=[], label=nickname_a),  # Reset left Chatbot
        gr.update(value=[], label=nickname_b),  # Reset right Chatbot
        None,  # Clear left model
        None,  # Clear right model
        gr.update(interactive=False, value=f"Vote for {nickname_a}"),
        gr.update(interactive=False, value=f"Vote for {nickname_b}"),
        gr.update(interactive=False, visible=False),  # Reset Tie button
        gr.update(value="", visible=False),  # Clear and hide the status box
        gr.update(),  # Leaderboard unchanged
        gr.update(visible=False),  # Hide model names
        gr.update(),  # Chart unchanged
        0  # Reset tie_count
    )
# Translate a model identifier into its display name
def get_human_readable_name(model_name: str) -> str:
    """Return the display name paired with ``model_name`` in ``arena_config.APPROVED_MODELS``.

    Identifiers that are not in the approved list are returned unchanged.
    """
    display_names = dict(arena_config.APPROVED_MODELS)
    try:
        return display_names[model_name]
    except KeyError:
        return model_name
# Supplies the prompt box with one of the configured example prompts
def random_prompt():
    """Return one entry of ``arena_config.example_prompts`` chosen at random."""
    examples = arena_config.example_prompts
    return random.choice(examples)
# Handler for the tie button: extend both conversations with another prompt
def continue_conversation(prompt, left_chat, right_chat, left_model, right_model, previous_prompt, tie_count):
    """Ask both models a follow-up prompt after the user declares a tie.

    When ``prompt`` is empty or equal to ``previous_prompt``, a random example
    prompt is used instead. Both chat histories are extended in place with the
    prompt and the corresponding model's reply, and the tie count goes up by
    one. From the third tie onward the tie button is disabled.

    Returns:
        A 6-tuple: left chat update, right chat update, cleared prompt input,
        tie button update, the prompt that was used, the new tie count.
    """
    # Check if the prompt is empty or the same as the previous one
    if prompt == previous_prompt or not prompt:
        prompt = random.choice(arena_config.example_prompts)

    # Query the left model first, then the right one.
    exchanges = (
        (left_chat, call_ollama_api(left_model, prompt)),
        (right_chat, call_ollama_api(right_model, prompt)),
    )
    for history, reply in exchanges:
        history.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": reply},
        ])

    tie_count += 1
    if tie_count < 3:
        tie_button_state = gr.update(interactive=True)
    else:
        tie_button_state = gr.update(interactive=False, value="Max ties reached. Please vote!")

    return (
        gr.update(value=left_chat),
        gr.update(value=right_chat),
        gr.update(value=""),  # Clear the prompt input
        tie_button_state,
        prompt,  # Return the new prompt
        tie_count
    )
def get_fun_stats_html():
    """Render the statistics from ``get_fun_stats()`` as an HTML fragment.

    The fragment is a ``<style>`` block followed by a ``<div class="fun-stats">``
    holding the overview, hall-of-fame and epic-battles sections.
    """
    stats = get_fun_stats()

    # Static stylesheet; it has no interpolation, so it is a plain string.
    stylesheet = """
<style>
.fun-stats {
font-family: Arial, sans-serif;
font-size: 18px;
line-height: 1.6;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
.fun-stats h2 {
font-size: 36px;
color: inherit;
text-align: center;
margin-bottom: 20px;
}
.fun-stats h3 {
font-size: 28px;
color: inherit;
margin-top: 30px;
margin-bottom: 15px;
border-bottom: 2px solid currentColor;
padding-bottom: 10px;
}
.fun-stats ul {
list-style-type: none;
padding-left: 0;
}
.fun-stats li {
margin-bottom: 15px;
padding: 15px;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
.fun-stats .timestamp {
font-style: italic;
text-align: center;
margin-bottom: 20px;
}
.fun-stats .highlight {
font-weight: bold;
color: #e74c3c;
}
</style>
"""

    # Everything below is filled in from the stats dictionary.
    content = f"""<div class="fun-stats">
<h2>π Fun Arena Stats π</h2>
<p class="timestamp">Last updated: {stats['timestamp']}</p>
<h3>ποΈ Arena Overview</h3>
<p>Total Battles Fought: <span class="highlight">{stats['total_battles']}</span></p>
<p>Active Gladiators (Models): <span class="highlight">{stats['active_models']}</span></p>
<h3>π Hall of Fame</h3>
<p>π₯ Battle Veteran: <span class="highlight">{stats['most_battles']['model']}</span> ({stats['most_battles']['battles']} battles)</p>
<p>πΉ Sharpshooter: <span class="highlight">{stats['highest_win_rate']['model']}</span> (Win Rate: {stats['highest_win_rate']['win_rate']})</p>
<p>π Jack of All Trades: <span class="highlight">{stats['most_diverse_opponent']['model']}</span> (Faced {stats['most_diverse_opponent']['unique_opponents']} unique opponents)</p>
<p>π Underdog Champion: <span class="highlight">{stats['underdog_champion']['model']}</span> ({stats['underdog_champion']['size']} model with {stats['underdog_champion']['win_rate']} win rate)</p>
<p>βοΈ Mr. Consistent: <span class="highlight">{stats['most_consistent']['model']}</span> (Closest to 50% win rate, difference of {stats['most_consistent']['win_loss_difference']} wins/losses)</p>
<h3>π€Ό Epic Battles</h3>
<p>π€Ό Biggest Rivalry: <span class="highlight">{stats['biggest_rivalry']['model1']}</span> vs <span class="highlight">{stats['biggest_rivalry']['model2']}</span> ({stats['biggest_rivalry']['total_battles']} fierce battles!)</p>
<p>ποΈ David vs Goliath: <span class="highlight">{stats['david_vs_goliath']['david']}</span> (David) vs <span class="highlight">{stats['david_vs_goliath']['goliath']}</span> (Goliath)<br>
David won {stats['david_vs_goliath']['wins']} times despite being {stats['david_vs_goliath']['size_difference']} smaller!</p>
<p>π Comeback King: <span class="highlight">{stats['comeback_king']['model']}</span> (Overcame a {stats['comeback_king']['comeback_margin']}-battle deficit)</p>
<p>π Pyrrhic Victor: <span class="highlight">{stats['pyrrhic_victor']['model']}</span> (Lowest win rate among models with more wins than losses: {stats['pyrrhic_victor']['win_rate']})</p>
</div>
"""
    return stylesheet + content
def update_fun_stats_periodically(interval):
    """Regenerate the fun-stats HTML every ``interval`` seconds, forever.

    Meant to be the target of a daemon thread; it never returns. A failure in
    one cycle is logged and the loop carries on, so a single bad refresh does
    not end all later refreshes.

    Args:
        interval: Seconds to sleep between refreshes.
    """
    while True:
        time.sleep(interval)
        try:
            # NOTE(review): calling update() on the component from a
            # background thread presumably does not reach pages that are
            # already open, and the method may not exist on newer Gradio
            # releases -- confirm against the installed Gradio version.
            fun_stats_html.update(value=get_fun_stats_html())
        except Exception:
            # Top-level boundary of the thread: report and keep going.
            logger.exception("Periodic fun stats refresh failed")
# Build the Gradio UI: one tab per view, then the event wiring.
with gr.Blocks(css="""
#dice-button {
min-height: 90px;
font-size: 35px;
}
""") as demo:
    gr.Markdown(arena_config.ARENA_NAME)
    gr.Markdown(arena_config.ARENA_DESCRIPTION)

    # Leaderboard Tab (first)
    with gr.Tab("Leaderboard"):
        leaderboard = gr.HTML(label="Leaderboard")

    # Battle Arena Tab (second)
    with gr.Tab("Battle Arena"):
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                scale=20
            )
            random_prompt_btn = gr.Button("π²", scale=1, elem_id="dice-button")
        gr.Markdown("<br>")

        # The dice button fills the prompt box with a random example prompt.
        random_prompt_btn.click(
            random_prompt,
            outputs=prompt_input
        )

        submit_btn = gr.Button("Generate Responses", variant="primary")

        with gr.Row():
            left_output = gr.Chatbot(label=random.choice(arena_config.model_nicknames), type="messages")
            right_output = gr.Chatbot(label=random.choice(arena_config.model_nicknames), type="messages")

        with gr.Row():
            left_vote_btn = gr.Button(f"Vote for {left_output.label}", interactive=False)
            tie_btn = gr.Button("Tie π Continue with a new prompt", interactive=False, visible=False)
            right_vote_btn = gr.Button(f"Vote for {right_output.label}", interactive=False)

        result = gr.Textbox(
            label="Status",
            interactive=False,
            value="Generate responses to start the battle! π",
            visible=True  # Visible from the start
        )

        # The real model names stay hidden until a vote has been recorded.
        with gr.Row(visible=False) as model_names_row:
            left_model = gr.Textbox(label="π΅ Left Model", interactive=False)
            right_model = gr.Textbox(label="π΄ Right Model", interactive=False)

        previous_prompt = gr.State("")  # Last prompt sent to the models
        tie_count = gr.State(0)  # Number of ties declared in the current battle

        new_battle_btn = gr.Button("New Battle")

    # Performance Chart Tab
    with gr.Tab("Performance Chart"):
        leaderboard_chart = gr.Plot(label="Model Performance Chart")

    # ELO Leaderboard Tab
    with gr.Tab("ELO Leaderboard"):
        elo_leaderboard = gr.HTML(label="ELO Leaderboard")

    # Fun Stats Tab
    with gr.Tab("Fun Stats"):
        refresh_btn = gr.Button("Refresh Stats")
        fun_stats_html = gr.HTML(label="Fun Arena Stats")

    # Suggest Models Tab
    with gr.Tab("Suggest Models"):
        with gr.Row():
            model_url_input = gr.Textbox(
                label="Model URL",
                placeholder="hf.co/username/model-name-GGUF:Q4_K_M",
                scale=4
            )
            submit_suggestion_btn = gr.Button("Submit Suggestion", scale=1, variant="primary")
        suggestion_status = gr.Markdown("Submit a model to see it listed below!")
        suggestions_list = gr.HTML(get_suggestions_html())
        refresh_suggestions_btn = gr.Button("Refresh List")

        # Submitting stores the suggestion, then refreshes the list and
        # clears the input box.
        submit_suggestion_btn.click(
            add_suggestion,
            inputs=[model_url_input],
            outputs=[suggestion_status]
        ).then(
            lambda: (
                get_suggestions_html(),  # Update suggestions list
                ""  # Clear model URL input
            ),
            outputs=[
                suggestions_list,
                model_url_input
            ]
        )
        refresh_suggestions_btn.click(
            get_suggestions_html,
            outputs=[suggestions_list]
        )

    # Define interactions
    submit_btn.click(
        battle_arena,
        inputs=prompt_input,
        outputs=[
            left_output, right_output, left_model, right_model,
            left_output, right_output, left_vote_btn, right_vote_btn,
            tie_btn, previous_prompt, tie_count, model_names_row, result
        ]
    )
    left_vote_btn.click(
        lambda *args: record_vote(*args, "Left is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )
    right_vote_btn.click(
        lambda *args: record_vote(*args, "Right is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )
    tie_btn.click(
        continue_conversation,
        inputs=[prompt_input, left_output, right_output, left_model, right_model, previous_prompt, tie_count],
        outputs=[left_output, right_output, prompt_input, tie_btn, previous_prompt, tie_count]
    )
    new_battle_btn.click(
        new_battle,
        outputs=[prompt_input, left_output, right_output, left_model,
                 right_model, left_vote_btn, right_vote_btn, tie_btn,
                 result, leaderboard, model_names_row, leaderboard_chart, tie_count]
    )

    # Update leaderboard and chart on launch
    demo.load(get_leaderboard, outputs=leaderboard)
    demo.load(get_elo_leaderboard, outputs=elo_leaderboard)
    demo.load(get_leaderboard_chart, outputs=leaderboard_chart)
    demo.load(get_fun_stats_html, outputs=fun_stats_html)

    # Manual refresh of the fun stats
    refresh_btn.click(get_fun_stats_html, outputs=fun_stats_html)

    # Refresh the fun stats every hour. A timer event is used because a
    # component can only be updated through an event's outputs; calling
    # update() on it from a background thread does not reach the browser.
    fun_stats_timer = gr.Timer(3600)
    fun_stats_timer.tick(get_fun_stats_html, outputs=fun_stats_html)
if __name__ == "__main__":
    # Script entry point: set up the ELO ratings first, then serve the UI
    # with the API page switched off.
    ensure_elo_ratings_initialized()
    demo.launch(show_api=False)