import csv
import os
from random import randint, sample

import gradio as gr

from all_models import models

# Header row written at the top of the score CSV.
SCORE_HEADER = ["Model Name", "Score"]


def init_model_scores(file_path='model_scores.csv'):
    """Create the score CSV with a zero score per model if it does not exist yet.

    An existing file is left untouched, so repeated calls are harmless.

    Args:
        file_path: Path of the CSV file that stores per-model vote counts.
    """
    # Use the parameter itself; the previous body read an undefined
    # ``csv_file_path`` and raised NameError on every call.
    if os.path.isfile(file_path):
        return
    with open(file_path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(SCORE_HEADER)
        # One row per model, every score starting at 0.
        writer.writerows([model, 0] for model in models)


def update_elo_ratings(user_vote, csv_file_path='model_scores.csv'):
    """Add one vote for ``user_vote`` to the score CSV.

    Despite the name, no Elo computation happens. The chosen model's score is
    incremented by 1 (a model missing from the file starts from 0), and the
    whole file is rewritten.

    Args:
        user_vote: Name of the model the user preferred.
        csv_file_path: Path of the score CSV. It must already exist
            (see ``init_model_scores``).

    Raises:
        FileNotFoundError: If ``csv_file_path`` does not exist.
        ValueError: If a stored score is not an integer.
    """
    scores = {}
    with open(csv_file_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if row:  # ignore blank lines instead of failing with IndexError
                scores[row[0]] = int(row[1])

    scores[user_vote] = scores.get(user_vote, 0) + 1

    with open(csv_file_path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(SCORE_HEADER)
        writer.writerows(scores.items())


def compare_models(prompt):
    """Generate one image each from two randomly sampled models for ``prompt``.

    Relies on ``gen_fn`` (defined further down in this module) returning an
    ``(image, model_name)`` pair.

    Returns:
        ``(image1, model_name1, image2, model_name2)``.
    """
    model1, model2 = sample(models, 2)
    image1, model_name1 = gen_fn(model1, prompt)
    image2, model_name2 = gen_fn(model2, prompt)
    return image1, model_name1, image2, model_name2


def handle_vote(user_vote):
    """Record a user's preference for the model named ``user_vote``.

    Ensures the score file exists, then adds one vote for ``user_vote``.
    """
    init_model_scores()
    update_elo_ratings(user_vote)


def display_leaderboard():
    """Placeholder for showing the leaderboard; not implemented yet.

    TODO: read the score CSV and render models ordered by score.
    """
    pass
def load_fn(models):
    """Load a Gradio interface for each model name into the global ``models_load``.

    A model that fails to load is replaced by a stub interface that returns
    no image, so the rest of the app still builds.

    Args:
        models: Iterable of model names, loaded as ``models/<name>``.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            continue
        try:
            m = gr.load(f'models/{model}')
        except Exception as error:
            # Deliberately best-effort: report the failure and keep going.
            print(f'Could not load model {model!r}: {error}')
            m = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[model] = m


load_fn(models)

num_models = 6
default_models = models[:num_models]


def extend_choices(choices):
    """Pad ``choices`` with 'NA' placeholders up to ``num_models`` entries."""
    return choices + (num_models - len(choices)) * ['NA']


def update_imgbox(choices):
    """Return one image box per padded choice, hiding the 'NA' placeholders."""
    choices_plus = extend_choices(choices)
    return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]


def gen_fn(model_str, prompt):
    """Generate an image from ``model_str`` for ``prompt``.

    A random number is appended to the prompt so repeated requests with the
    same prompt are not identical.

    Returns:
        ``(image, model_str)``, or ``(None, None)`` for the 'NA' placeholder.
    """
    if model_str == 'NA':
        return None, None
    noise = str(randint(0, 99999999999))
    image = models_load[model_str](f'{prompt} {noise}')
    return image, model_str


max_images = 6


def _build_model_panel(title):
    """Build one model panel (model picker, prompt, image grid, buttons) in the current Blocks."""
    # gr.Column takes keyword-only arguments; the title is shown as a heading
    # instead of being passed positionally (which raised TypeError).
    with gr.Column(variant='panel'):
        gr.Markdown(f'### {title}')
        model_choice = gr.Dropdown(models, label='Choose model', value=models[0], filterable=False)
        txt_input = gr.Textbox(label='Prompt text')
        num_images = gr.Slider(1, max_images, value=max_images, step=1, label='Number of images')
        gen_button = gr.Button('Generate')
        stop_button = gr.Button('Stop', variant='secondary', interactive=False)
        # inputs=None means the handler is called with no arguments.
        gen_button.click(lambda: gr.update(interactive=True), None, stop_button)

        with gr.Row():
            outputs = [gr.Image(label='') for _ in range(max_images)]

        gen_events = []
        for i, o in enumerate(outputs):
            # Hidden component that feeds this slot's index into the handlers.
            img_i = gr.Number(i, visible=False)
            num_images.change(lambda i, n: gr.update(visible=(i < n)), [img_i, num_images], o)
            # gen_fn returns (image, model_name); each output slot takes only the image.
            gen_events.append(gen_button.click(
                lambda i, n, m, t: gen_fn(m, t)[0] if i < n else None,
                [img_i, num_images, model_choice, txt_input],
                o,
            ))
        stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=gen_events)


with gr.Blocks() as ImageGenarationArena:
    _build_model_panel('model A')
    _build_model_panel('model B')

ImageGenarationArena.queue(concurrency_count=36)
ImageGenarationArena.launch()