File size: 5,777 Bytes
374ecac
ec8ef23
374ecac
ec8ef23
 
374ecac
ec8ef23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374ecac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7169302
374ecac
 
 
 
 
 
 
 
 
527535b
374ecac
 
8fd30e2
374ecac
 
 
8fd30e2
ec8ef23
 
 
 
 
 
 
 
374ecac
 
ec8ef23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc23ea5
ec8ef23
 
 
 
 
374ecac
ec8ef23
 
374ecac
 
4735bd9
 
 
 
374ecac
4735bd9
 
 
374ecac
4735bd9
374ecac
4735bd9
 
 
 
 
374ecac
4735bd9
8613c00
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
from random import randint, sample
from all_models import models
import csv
import os

# Vote-score storage: a CSV file with a "Model Name,Score" header and one row per model.
def init_model_scores(file_path='model_scores.csv', model_names=None):
    """Create the vote-score CSV with a zero score for every model, if it is missing.

    An existing file is left untouched so previously recorded votes survive.

    Args:
        file_path: Path of the CSV file holding ``Model Name,Score`` rows.
        model_names: Model names to seed with a score of 0. Defaults to the
            module-level ``models`` list.
    """
    # Only create the file once; never overwrite accumulated scores.
    if os.path.isfile(file_path):
        return
    if model_names is None:
        model_names = models
    with open(file_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Model Name", "Score"])
        # One entry per model, starting at 0.
        writer.writerows([model, 0] for model in model_names)
    
def update_elo_ratings(user_vote, csv_file_path='model_scores.csv'):
    """Record one vote for ``user_vote`` in the score CSV.

    Despite the name, this is a plain vote counter rather than an Elo update.
    The chosen model's score goes up by 1, and a model not yet in the file is
    added with a score of 1. The whole file is then rewritten.

    Args:
        user_vote: Name of the model the user preferred.
        csv_file_path: Path of the CSV file holding ``Model Name,Score`` rows.
            A missing or empty file is treated as having no scores yet.
    """
    scores = {}
    if os.path.isfile(csv_file_path):
        # newline='' is what the csv module expects for files it parses.
        with open(csv_file_path, 'r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header row; tolerates an empty file.
            for row in reader:
                if len(row) < 2:  # Ignore blank or truncated lines.
                    continue
                scores[row[0]] = int(row[1])

    scores[user_vote] = scores.get(user_vote, 0) + 1

    # Write the updated scores back, header first.
    with open(csv_file_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Model Name", "Score"])
        writer.writerows(scores.items())


# Function to compare two models
def compare_models(prompt):
    """Run ``prompt`` through two distinct, randomly sampled models.

    Returns:
        A 4-tuple ``(image, name, image, name)``: the ``gen_fn`` result for
        the first sampled model followed by the result for the second.
    """
    flattened = []
    for picked in sample(models, 2):
        picked_image, picked_name = gen_fn(picked, prompt)
        flattened.extend((picked_image, picked_name))
    return tuple(flattened)

# User voting logic
def handle_vote(user_vote):
    """Persist a single vote.

    Calls ``init_model_scores`` to make sure the score file exists, then passes
    the vote to ``update_elo_ratings``. Returns ``None``.

    Args:
        user_vote: Presumably the name of the preferred model (it is used as the
            CSV key by ``update_elo_ratings``) — confirm against the caller.
    """
    # The score file has to exist before it can be read back and updated.
    init_model_scores()
    update_elo_ratings(user_vote)
    
# Leaderboard display logic
def display_leaderboard():
    """Placeholder for the leaderboard view; not implemented yet.

    Currently performs no work and returns ``None``.
    """
    # TODO: read the score CSV and present models ranked by score.
    return None

# Model loading and Gradio UI helpers.
def load_fn(models):
    """Load an inference interface for each model into the global ``models_load``.

    ``models_load`` is reset to an empty dict. Each model name is then mapped
    to the object returned by ``gr.load(f'models/{name}')``. When loading
    fails, the failure is logged and a placeholder ``gr.Interface`` (its
    function always returns ``None``) is stored instead, so the UI can still
    be built.

    Args:
        models: Iterable of model identifiers (the ``models/`` prefix is added here).
    """
    global models_load
    import logging

    models_load = {}

    for model in models:
        if model in models_load:  # Duplicate name: already loaded.
            continue
        try:
            m = gr.load(f'models/{model}')
        except Exception:
            # Best effort: keep building the app, but leave a trace of why.
            logging.getLogger(__name__).warning(
                "Could not load model %r; using a placeholder interface",
                model,
                exc_info=True,
            )
            m = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[model] = m


# Build the inference interfaces for every known model up front.
load_fn(models)


# Number of entries that extend_choices pads a selection to.
num_models = 6
# The first ``num_models`` models from the model list.
default_models = models[0:num_models]


def extend_choices(choices):
    """Pad ``choices`` with ``'NA'`` placeholders up to ``num_models`` entries.

    A new list is always returned. Lists already at or above ``num_models``
    come back with the same items and are not truncated.
    """
    shortfall = max(num_models - len(choices), 0)
    return choices + ['NA'] * shortfall


def update_imgbox(choices):
    """Build one empty ``gr.Image`` per (padded) choice.

    Each image is labelled with its model name. It is hidden when the name is
    the ``'NA'`` padding placeholder added by ``extend_choices``.
    """
    return [
        gr.Image(None, label=name, visible=name != 'NA')
        for name in extend_choices(choices)
    ]

    
def gen_fn(model_str, prompt):
    """Generate one image for ``prompt`` with the model named ``model_str``.

    Args:
        model_str: Key into ``models_load``, or the ``'NA'`` placeholder.
        prompt: Text prompt sent to the model.

    Returns:
        ``(image, model_str)``, where ``image`` is whatever the loaded model
        interface returns. Returns ``(None, None)`` when ``model_str`` is
        ``'NA'``.
    """
    if model_str == 'NA':
        return None, None
    # A random suffix makes every request's prompt unique — presumably so a
    # repeated prompt is not answered from a cache; confirm.
    noise = str(randint(0, 99999999999))
    image = models_load[model_str](f'{prompt} {noise}')
    return image, model_str


max_images = 6


def _generate_slot(index, count, model_name, prompt):
    """Return the image for one output slot, or ``None`` when the slot is unused.

    ``gen_fn`` returns an ``(image, model_name)`` pair, but each slot is wired
    to a single ``gr.Image`` output, so only the image is passed on.
    """
    if index >= count:
        return None
    image, _ = gen_fn(model_name, prompt)
    return image


def _build_model_panel(title):
    """Build one panel: model picker, prompt box, slot slider, Generate/Stop buttons and image grid."""
    with gr.Column(variant='panel'):
        gr.Markdown(f'### {title}')
        model_choice = gr.Dropdown(models, label='Choose model', value=models[0], filterable=False)
        txt_input = gr.Textbox(label='Prompt text')
        num_images = gr.Slider(1, max_images, value=max_images, step=1, label='Number of images')

        gen_button = gr.Button('Generate')
        stop_button = gr.Button('Stop', variant='secondary', interactive=False)
        # No inputs are wired to this listener, so the callback takes no arguments.
        gen_button.click(lambda: gr.update(interactive=True), None, stop_button)

        with gr.Row():
            outputs = [gr.Image(label='') for _ in range(max_images)]

        gen_events = []
        for i, output in enumerate(outputs):
            # Hidden component that feeds this slot's index into the callbacks.
            slot_index = gr.Number(i, visible=False)
            num_images.change(
                lambda idx, n: gr.update(visible=(idx < n)), [slot_index, num_images], output
            )
            gen_events.append(
                gen_button.click(
                    _generate_slot, [slot_index, num_images, model_choice, txt_input], output
                )
            )
        # A single Stop handler cancels every pending generation in this panel.
        stop_button.click(
            lambda: gr.update(interactive=False), None, stop_button, cancels=gen_events
        )


with gr.Blocks() as ImageGenarationArena:
    _build_model_panel('model A')
    _build_model_panel('model B')
    
# Enable request queuing, then start the server; queue() returns the Blocks
# object, so launch() is chained onto it.
# NOTE(review): concurrency_count is a gradio 3.x queue() argument — confirm the pinned gradio version.
ImageGenarationArena.queue(concurrency_count=36).launch()