K00B404's picture
Update app.py
8613c00 verified
raw
history blame
5.78 kB
import gradio as gr
from random import randint, sample
from all_models import models
import csv
import os
# Vote-score persistence helpers: scores are kept as plain vote counts in a CSV file.
def init_model_scores(file_path='model_scores.csv', model_names=None):
    """Create the score CSV with a zero score per model if it is missing.

    An existing file is left untouched so that previously recorded votes
    are never reset.

    Args:
        file_path: Path of the CSV file that stores the scores.
        model_names: Iterable of model names used to seed a new file.
            Defaults to the module-level ``models`` list.
    """
    # The original body referred to an undefined ``csv_file_path``; the
    # parameter is called ``file_path``.
    if os.path.isfile(file_path):
        return
    if model_names is None:
        model_names = models
    with open(file_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Model Name", "Score"])
        # One row per model, every score starting at zero.
        writer.writerows([name, 0] for name in model_names)
def update_elo_ratings(user_vote, csv_file_path='model_scores.csv'):
    """Record one vote for ``user_vote`` in the score CSV.

    Despite the name, no ELO computation happens here: the chosen model's
    score is incremented by one (a model not yet listed starts at 1) and
    the whole table is written back.

    Args:
        user_vote: Name of the model the user preferred.
        csv_file_path: Path of the CSV file that stores the scores.

    Raises:
        ValueError: If an existing score cell is not an integer.
    """
    scores = {}
    try:
        # newline='' is what the csv module expects for files it parses.
        with open(csv_file_path, 'r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header; tolerate an empty file.
            for row in reader:
                if len(row) < 2:
                    # Blank or truncated rows carry no score.
                    continue
                scores[row[0]] = int(row[1])
    except FileNotFoundError:
        # No score file yet: start from an empty table.
        pass

    scores[user_vote] = scores.get(user_vote, 0) + 1

    with open(csv_file_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Model Name", "Score"])
        writer.writerows(scores.items())
# Function to compare two models
def compare_models(prompt):
    """Generate one image from each of two randomly chosen, distinct models.

    Args:
        prompt: Text prompt passed to both models.

    Returns:
        A 4-tuple ``(image1, model_name1, image2, model_name2)``.
    """
    contenders = sample(models, 2)
    flattened = []
    for contender in contenders:
        picture, label = gen_fn(contender, prompt)
        flattened += [picture, label]
    return tuple(flattened)
# User voting logic
def handle_vote(user_vote):
    """Persist a user's vote.

    Args:
        user_vote: Name of the model the user preferred.
    """
    # Initialise the score file first, then record the vote in it.
    init_model_scores()
    update_elo_ratings(user_vote)
# Leaderboard display logic
def display_leaderboard():
    """Placeholder for the leaderboard view.

    Not implemented yet: it does nothing and returns ``None``.
    """
    return None
# Gradio setup: model loading and UI helper functions.
def load_fn(models):
    """Rebuild the global ``models_load`` mapping of model name to interface.

    Each distinct name is loaded with ``gr.load`` from ``models/<name>``.
    If loading raises, a placeholder text-to-image interface that always
    yields ``None`` is stored instead, so a lookup by name still succeeds.

    Args:
        models: Iterable of model names; repeated names are loaded once.
    """
    global models_load
    models_load = {}

    for name in models:
        if name in models_load:
            continue
        try:
            interface = gr.load(f'models/{name}')
        except Exception:
            interface = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[name] = interface
# Load every model at import time; this fills the module-level `models_load`.
load_fn(models)
# Target length of the list returned by extend_choices (shorter lists are padded with 'NA').
num_models = 6
# The first `num_models` entries of `models`.
default_models = models[:num_models]
def extend_choices(choices):
    """Pad ``choices`` with ``'NA'`` entries up to ``num_models`` items.

    A list that already has ``num_models`` or more items comes back as an
    equal copy with no padding.

    Args:
        choices: List of selected model names.

    Returns:
        A new list: ``choices`` followed by the ``'NA'`` padding.
    """
    missing = num_models - len(choices)
    filler = ['NA'] * missing
    return choices + filler
def update_imgbox(choices):
    """Build one empty image component per padded choice.

    Each component is labelled with its model name; those for the ``'NA'``
    padding entries are created hidden.

    Args:
        choices: List of selected model names.

    Returns:
        A list of ``gr.Image`` components, one per entry of the padded list.
    """
    boxes = []
    for name in extend_choices(choices):
        is_real_model = name != 'NA'
        boxes.append(gr.Image(None, label = name, visible = is_real_model))
    return boxes
def gen_fn(model_str, prompt):
    """Generate an image with the named model and report which model made it.

    This replaces two consecutive definitions of ``gen_fn``; the first
    (returning only the image) was shadowed by the second and never ran, so
    only the effective one is kept.

    Args:
        model_str: Key into ``models_load``, or ``'NA'`` for an empty slot.
        prompt: Text prompt for the model.

    Returns:
        ``(image, model_str)``, or ``(None, None)`` when ``model_str`` is
        ``'NA'``.
    """
    if model_str == 'NA':
        return None, None
    # A random number is appended to the prompt so that repeated requests
    # send different text to the model.
    noise = str(randint(0, 99999999999))
    image = models_load[model_str](f'{prompt} {noise}')
    return image, model_str
with gr.Blocks() as ImageGenarationArena:
    # Two identical panels are built, one per title. Building them in a loop
    # replaces two copy-pasted sections; the module-level names end up bound
    # to the last panel's components.
    for panel_title in ('model A', 'model B'):
        # gr.Column takes keyword-only layout options and has no title,
        # width or height parameter, so the heading is rendered separately.
        with gr.Column(variant='panel') as col:
            gr.Markdown(f'### {panel_title}')
            model_choice2 = gr.Dropdown(models, label='Choose model', value=models[0], filterable=False)
            txt_input2 = gr.Textbox(label='Prompt text')

            max_images = 6
            num_images = gr.Slider(1, max_images, value=max_images, step=1, label='Number of images')

            gen_button2 = gr.Button('Generate')
            stop_button2 = gr.Button('Stop', variant='secondary', interactive=False)
            # No inputs are wired to this event, so the callback takes no
            # arguments. Clicking Generate makes the Stop button clickable.
            gen_button2.click(lambda: gr.update(interactive=True), None, stop_button2)

            with gr.Row():
                output2 = [gr.Image(label='') for _ in range(max_images)]

            for i, o in enumerate(output2):
                # Hidden number carrying this image slot's index into the callbacks.
                img_i = gr.Number(i, visible=False)
                # Show only the first `n` image slots.
                num_images.change(lambda i, n: gr.update(visible=(i < n)), [img_i, num_images], o)
                # gen_fn returns (image, model_name); the image component
                # takes only the image, hence the [0].
                gen_event2 = gen_button2.click(
                    lambda i, n, m, t: gen_fn(m, t)[0] if (i < n) else None,
                    [img_i, num_images, model_choice2, txt_input2],
                    o,
                )
                # Stop cancels this slot's generation and disables itself again.
                stop_button2.click(lambda: gr.update(interactive=False), None, stop_button2, cancels=[gen_event2])
# Configure the request queue, then start the app.
# NOTE(review): `concurrency_count` looks like the Gradio 3.x queue API and may be
# rejected by newer Gradio releases; confirm against the pinned Gradio version.
ImageGenarationArena.queue(concurrency_count = 36)
ImageGenarationArena.launch()