K00B404 committed
Commit ec8ef23
1 Parent(s): 85f8466

Update app.py

Files changed (1):
  app.py  +89 -22

app.py CHANGED
@@ -1,7 +1,65 @@
  import gradio as gr
- from random import randint
+ from random import randint, sample
  from all_models import models
+ import csv
+ import os
+
+ # Assuming you have a function to calculate ELO ratings
+ def init_model_scores(csv_file_path='model_scores.csv'):
+     # Check if the CSV file exists; if not, create it with headers
+     if not os.path.isfile(csv_file_path):
+         with open(csv_file_path, 'w', newline='') as file:
+             writer = csv.writer(file)
+             writer.writerow(["Model Name", "Score"])
+             for model in models:
+                 # make an entry for each model
+                 writer.writerow([model, 0])
+
+ def update_elo_ratings(user_vote, csv_file_path='model_scores.csv'):
+     # Logic to update ELO ratings based on user vote
+
+     # Read the current scores from the CSV file
+     scores = {}
+     with open(csv_file_path, 'r') as file:
+         reader = csv.reader(file)
+         next(reader)  # Skip the header row
+         for row in reader:
+             scores[row[0]] = int(row[1])
+
+     # Update the score for the selected model
+     if user_vote in scores:
+         scores[user_vote] += 1  # Increment the score
+     else:
+         scores[user_vote] = 1  # Add the model with a score of 1
+
+     # Write the updated scores back to the CSV file
+     with open(csv_file_path, 'w', newline='') as file:
+         writer = csv.writer(file)
+         writer.writerow(["Model Name", "Score"])  # Write the header row
+         for model, score in scores.items():
+             writer.writerow([model, score])
+
+
+ # Function to compare two models
+ def compare_models(prompt):
+     model1, model2 = sample(models, 2)
+     image1, model_name1 = gen_fn(model1, prompt)
+     image2, model_name2 = gen_fn(model2, prompt)
+     return image1, model_name1, image2, model_name2
+
+ # User voting logic
+ def handle_vote(user_vote):
+     init_model_scores()
+     # Assuming user_vote is a string indicating the preferred model
+     # Update ELO ratings based on user vote
+     update_elo_ratings(user_vote)
+
+ # Leaderboard display logic
+ def display_leaderboard():
+     # Logic to display leaderboard based on ELO ratings
+     pass

+ # Your existing Gradio setup code here...
  def load_fn(models):
      global models_load
      models_load = {}

@@ -36,31 +94,40 @@ def gen_fn(model_str, prompt):
          return None
      noise = str(randint(0, 99999999999))
      return models_load[model_str](f'{prompt} {noise}')
+
+ # Modified gen_fn function to return model name
+ def gen_fn(model_str, prompt):
+     if model_str == 'NA':
+         return None, None
+     noise = str(randint(0, 99999999999))
+     image = models_load[model_str](f'{prompt} {noise}')
+     return image, model_str


- with gr.Blocks() as demo:
-     with gr.Tab('Multiple models'):
-         with gr.Accordion('Model selection'):
-             model_choice = gr.CheckboxGroup(models, label = f'Choose up to {num_models} different models', value = default_models, multiselect = True, max_choices = num_models, interactive = True, filterable = False)
+ with gr.Blocks() as ImageGenarationArena:
+     with gr.Column('model A', variant='panel', width=2, height=150) as col:
+         #with gr.Tab('model B'):
+         model_choice2 = gr.Dropdown(models, label = 'Choose model', value = models[0], filterable = False)
+         txt_input2 = gr.Textbox(label = 'Prompt text')
+
+         max_images = 6
+         num_images = gr.Slider(1, max_images, value = max_images, step = 1, label = 'Number of images')
+
+         gen_button2 = gr.Button('Generate')
+         stop_button2 = gr.Button('Stop', variant = 'secondary', interactive = False)
+         gen_button2.click(lambda s: gr.update(interactive = True), None, stop_button2)
+
+         with gr.Row():
+             output2 = [gr.Image(label = '') for _ in range(max_images)]

+         for i, o in enumerate(output2):
+             img_i = gr.Number(i, visible = False)
+             num_images.change(lambda i, n: gr.update(visible = (i < n)), [img_i, num_images], o)
+             gen_event2 = gen_button2.click(lambda i, n, m, t: gen_fn(m, t) if (i < n) else None, [img_i, num_images, model_choice2, txt_input2], o)
+             stop_button2.click(lambda s: gr.update(interactive = False), None, stop_button2, cancels = [gen_event2])

-         txt_input = gr.Textbox(label = 'Prompt text')
-         gen_button = gr.Button('Generate')
-         stop_button = gr.Button('Stop', variant = 'secondary', interactive = False)
-         gen_button.click(lambda s: gr.update(interactive = True), None, stop_button)
-
-         with gr.Row():
-             output = [gr.Image(label = m) for m in default_models]
-             current_models = [gr.Textbox(m, visible = False) for m in default_models]
-
-         model_choice.change(update_imgbox, model_choice, output)
-         model_choice.change(extend_choices, model_choice, current_models)
-
-         for m, o in zip(current_models, output):
-             gen_event = gen_button.click(gen_fn, [m, txt_input], o, queue=False)
-
-
-     with gr.Tab('Single model'):
+     with gr.Column('model B', variant='panel', width=2, height=150) as col:
+         #with gr.Tab('model A'):
          model_choice2 = gr.Dropdown(models, label = 'Choose model', value = models[0], filterable = False)
          txt_input2 = gr.Textbox(label = 'Prompt text')

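Note that, despite the "ELO ratings" comments, update_elo_ratings in this commit only increments a per-model win counter in model_scores.csv. For reference, a minimal sketch of an actual Elo update is shown below; it is not part of the commit, and the function name, K-factor, and float-valued ratings are assumptions.

# Hypothetical sketch only -- a standard Elo update, not the committed logic.
# Assumes the winner's and loser's current ratings are available as floats.
def elo_update(winner_rating, loser_rating, k=32):
    # Expected score of the winner under the Elo model
    expected_win = 1 / (1 + 10 ** ((loser_rating - winner_rating) / 400))
    new_winner = winner_rating + k * (1 - expected_win)
    new_loser = loser_rating - k * (1 - expected_win)
    return new_winner, new_loser

Using this would require recording both models shown in a comparison, not only the winner, since both ratings change on every vote.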
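display_leaderboard is committed as a pass stub. One possible way to fill it in, assuming the model_scores.csv layout written by init_model_scores and update_elo_ratings, is to read the file and return the rows sorted by score, for example to feed a gr.Dataframe:

# Sketch only: reads model_scores.csv and returns (model, score) rows, best first.
def display_leaderboard(csv_file_path='model_scores.csv'):
    with open(csv_file_path, 'r', newline='') as file:
        reader = csv.reader(file)
        next(reader)  # skip the "Model Name", "Score" header
        rows = [(name, int(score)) for name, score in reader]
    return sorted(rows, key=lambda row: row[1], reverse=True)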
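compare_models and handle_vote are defined in this commit but not yet wired into the Blocks UI. A possible hookup is sketched below; every component name in it (prompt_box, img_a, img_b, name_a, name_b, the vote buttons) is hypothetical and not part of the commit.

# Sketch only: hypothetical arena wiring for compare_models / handle_vote.
with gr.Blocks() as arena_sketch:
    prompt_box = gr.Textbox(label='Prompt text')
    compare_button = gr.Button('Generate pair')
    with gr.Row():
        img_a = gr.Image(label='Model A')
        img_b = gr.Image(label='Model B')
    name_a = gr.Textbox(visible=False)
    name_b = gr.Textbox(visible=False)
    with gr.Row():
        vote_a = gr.Button('Vote A')
        vote_b = gr.Button('Vote B')

    # compare_models returns (image1, model_name1, image2, model_name2)
    compare_button.click(compare_models, prompt_box, [img_a, name_a, img_b, name_b])
    # handle_vote expects the winning model's name as a string
    vote_a.click(handle_vote, name_a, None)
    vote_b.click(handle_vote, name_b, None)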