sohojoe commited on
Commit
80c0743
1 Parent(s): 76f0068

fix UI for gradio upgrade

Browse files
Files changed (1) hide show
  1. app.py +84 -17
app.py CHANGED
@@ -31,8 +31,14 @@ embedding_powers = [1. for i in range(max_tabs)]
31
  embedding_base64s = [None for i in range(max_tabs)]
32
  # embedding_base64s = gr.State(value=[None for i in range(max_tabs)])
33
 
 
 
 
 
 
34
 
35
  def image_to_embedding(input_im):
 
36
  input_im = Image.fromarray(input_im)
37
  prepro = preprocess(input_im).unsqueeze(0).to(device)
38
  with torch.no_grad():
@@ -42,6 +48,7 @@ def image_to_embedding(input_im):
42
  return image_embeddings_np
43
 
44
  def prompt_to_embedding(prompt):
 
45
  text = tokenizer([prompt]).to(device)
46
  with torch.no_grad():
47
  prompt_embededdings = model.encode_text(text)
@@ -50,6 +57,7 @@ def prompt_to_embedding(prompt):
50
  return prompt_embededdings_np
51
 
52
  def embedding_to_image(embeddings):
 
53
  size = math.ceil(math.sqrt(embeddings.shape[0]))
54
  image_embeddings_square = np.pad(embeddings, (0, size**2 - embeddings.shape[0]), 'constant')
55
  image_embeddings_square.resize(size,size)
@@ -57,6 +65,7 @@ def embedding_to_image(embeddings):
57
  return embedding_image
58
 
59
  def embedding_to_base64(embeddings):
 
60
  import base64
61
  # ensure float32
62
  embeddings = embeddings.astype(np.float32)
@@ -64,12 +73,22 @@ def embedding_to_base64(embeddings):
64
  return embeddings_b64
65
 
66
  def base64_to_embedding(embeddings_b64):
 
67
  import base64
68
  embeddings = base64.urlsafe_b64decode(embeddings_b64)
69
  embeddings = np.frombuffer(embeddings, dtype=np.float32)
70
  # embeddings = torch.tensor(embeddings)
71
  return embeddings
72
 
 
 
 
 
 
 
 
 
 
73
  def safe_url(url):
74
  import urllib.parse
75
  url = urllib.parse.quote(url, safe=':/')
@@ -83,6 +102,7 @@ def main(
83
  embeddings,
84
  n_samples=4,
85
  ):
 
86
 
87
  embeddings = base64_to_embedding(embeddings)
88
  # convert to python array
@@ -117,17 +137,21 @@ def main(
117
  return images
118
 
119
  def on_image_load_update_embeddings(image_data):
 
120
  # image to embeddings
121
  if image_data is None:
122
  # embeddings = prompt_to_embedding('')
123
  # embeddings_b64 = embedding_to_base64(embeddings)
124
  # return gr.Text.update(embeddings_b64)
125
- return gr.Text.update('')
 
126
  embeddings = image_to_embedding(image_data)
127
  embeddings_b64 = embedding_to_base64(embeddings)
128
- return gr.Text.update(embeddings_b64)
 
129
 
130
  def on_prompt_change_update_embeddings(prompt):
 
131
  # prompt to embeddings
132
  if prompt is None or prompt == "":
133
  embeddings = prompt_to_embedding('')
@@ -135,9 +159,10 @@ def on_prompt_change_update_embeddings(prompt):
135
  return gr.Text.update(embedding_to_base64(embeddings))
136
  embeddings = prompt_to_embedding(prompt)
137
  embeddings_b64 = embedding_to_base64(embeddings)
138
- return gr.Text.update(embeddings_b64)
139
 
140
  def update_average_embeddings(embedding_base64s_state, embedding_powers):
 
141
  final_embedding = None
142
  num_embeddings = 0
143
  for i, embedding_base64 in enumerate(embedding_base64s_state):
@@ -154,7 +179,7 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
154
  # embeddings = prompt_to_embedding('')
155
  # embeddings_b64 = embedding_to_base64(embeddings)
156
  # return gr.Text.update(embeddings_b64)
157
- return gr.Text.update('')
158
 
159
  # TODO toggle this to support average or sum
160
  # final_embedding = final_embedding / num_embeddings
@@ -166,22 +191,25 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
166
  return embeddings_b64
167
 
168
  def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
 
169
  embedding_power_state[idx] = power
170
  embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
171
- return gr.Text.update(embeddings_b64)
172
 
173
  def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
 
174
  embedding_base64s_state[idx] = embedding_base64 if embedding_base64 != '' else None
175
  embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
176
- return gr.Text.update(embeddings_b64)
177
 
178
  def on_embeddings_changed_update_plot(embeddings_b64):
 
179
  # plot new embeddings
180
  if embeddings_b64 is None or embeddings_b64 == "":
181
  data = pd.DataFrame({
182
  'embedding': [],
183
  'index': []})
184
- return gr.LinePlot.update(data,
185
  x="index",
186
  y="embedding",
187
  # color="country",
@@ -192,6 +220,7 @@ def on_embeddings_changed_update_plot(embeddings_b64):
192
  # stroke_dash_legend_title="Country Cluster",
193
  # height=300,
194
  width=0)
 
195
 
196
  embeddings = base64_to_embedding(embeddings_b64)
197
  data = pd.DataFrame({
@@ -210,6 +239,7 @@ def on_embeddings_changed_update_plot(embeddings_b64):
210
  width=embeddings.shape[0])
211
 
212
  def on_example_image_click_set_image(input_image, image_url):
 
213
  input_image.value = image_url
214
 
215
  # device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
@@ -236,7 +266,7 @@ examples = [
236
  # ["SohoJoeEth.jpeg", "Snoop Dogg.jpg", "SohoJoeEth + Snoop Dogg.jpeg"],
237
  ["pup1.jpg", "", "Pup no teacup.jpg"],
238
  ]
239
- tile_size = 100
240
  # image_folder = os.path.join("file", "images")
241
  image_folder ="images"
242
 
@@ -349,7 +379,7 @@ Try uploading a few images and/or add some text prompts and search the embedding
349
  # input_image.change(on_image_load, inputs= [input_image, plot])
350
  with gr.Row():
351
  with gr.Column(scale=2, min_width=240):
352
- input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True)
353
  with gr.Column(scale=3, min_width=600):
354
  with gr.Row():
355
  # with gr.Slider(min=-5, max=5, value=1, label="Power", show_label=True):
@@ -357,7 +387,7 @@ Try uploading a few images and/or add some text prompts and search the embedding
357
  embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
358
  with gr.Row():
359
  with gr.Accordion(f"Embeddings (base64)", open=False):
360
- embedding_base64s[i] = gr.Textbox(show_label=False)
361
  for idx, (tab_title, examples) in enumerate(tabbed_examples.items()):
362
  with gr.Tab(tab_title):
363
  with gr.Row():
@@ -395,15 +425,52 @@ Try uploading a few images and/or add some text prompts and search the embedding
395
 
396
  embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
397
  embedding_power_state = gr.State(value=[1. for i in range(max_tabs)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  for i in range(max_tabs):
399
- input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
400
- input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
401
- embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
402
  idx_state = gr.State(value=i)
403
- embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_base64s[i], idx_state], average_embedding_base64)
404
- embedding_powers[i].change(on_power_change_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_powers[i], idx_state], average_embedding_base64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
- average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)
407
 
408
  # submit.click(main, inputs= [embedding_base64s[0], scale, n_samples, steps, seed], outputs=output)
409
  submit.click(main, inputs= [average_embedding_base64, n_samples], outputs=output)
@@ -439,4 +506,4 @@ My interest is to use CLIP for image/video understanding (see [CLIP_visual-spati
439
  # ![Alt Text](file/pup1.jpg){height=100 width=100}
440
 
441
  if __name__ == "__main__":
442
- demo.launch()
 
31
  embedding_base64s = [None for i in range(max_tabs)]
32
  # embedding_base64s = gr.State(value=[None for i in range(max_tabs)])
33
 
34
+ debug_print_on = False
35
+
36
+ def debug_print(*args, **kwargs):
37
+ if debug_print_on:
38
+ print(*args, **kwargs)
39
 
40
  def image_to_embedding(input_im):
41
+ # debug_print("image_to_embedding")
42
  input_im = Image.fromarray(input_im)
43
  prepro = preprocess(input_im).unsqueeze(0).to(device)
44
  with torch.no_grad():
 
48
  return image_embeddings_np
49
 
50
  def prompt_to_embedding(prompt):
51
+ # debug_print("prompt_to_embedding")
52
  text = tokenizer([prompt]).to(device)
53
  with torch.no_grad():
54
  prompt_embededdings = model.encode_text(text)
 
57
  return prompt_embededdings_np
58
 
59
  def embedding_to_image(embeddings):
60
+ # debug_print("embedding_to_image")
61
  size = math.ceil(math.sqrt(embeddings.shape[0]))
62
  image_embeddings_square = np.pad(embeddings, (0, size**2 - embeddings.shape[0]), 'constant')
63
  image_embeddings_square.resize(size,size)
 
65
  return embedding_image
66
 
67
  def embedding_to_base64(embeddings):
68
+ # debug_print("embedding_to_base64")
69
  import base64
70
  # ensure float32
71
  embeddings = embeddings.astype(np.float32)
 
73
  return embeddings_b64
74
 
75
  def base64_to_embedding(embeddings_b64):
76
+ # debug_print("base64_to_embedding")
77
  import base64
78
  embeddings = base64.urlsafe_b64decode(embeddings_b64)
79
  embeddings = np.frombuffer(embeddings, dtype=np.float32)
80
  # embeddings = torch.tensor(embeddings)
81
  return embeddings
82
 
83
+ def is_prompt_embeddings(prompt):
84
+ if prompt is None or prompt == "":
85
+ return False
86
+ try:
87
+ embedding = base64_to_embedding(prompt)
88
+ return True
89
+ except Exception as e:
90
+ return False
91
+
92
  def safe_url(url):
93
  import urllib.parse
94
  url = urllib.parse.quote(url, safe=':/')
 
102
  embeddings,
103
  n_samples=4,
104
  ):
105
+ debug_print("main")
106
 
107
  embeddings = base64_to_embedding(embeddings)
108
  # convert to python array
 
137
  return images
138
 
139
  def on_image_load_update_embeddings(image_data):
140
+ debug_print("on_image_load_update_embeddings")
141
  # image to embeddings
142
  if image_data is None:
143
  # embeddings = prompt_to_embedding('')
144
  # embeddings_b64 = embedding_to_base64(embeddings)
145
  # return gr.Text.update(embeddings_b64)
146
+ # return gr.Text.update('')
147
+ return ''
148
  embeddings = image_to_embedding(image_data)
149
  embeddings_b64 = embedding_to_base64(embeddings)
150
+ # return gr.Text.update(embeddings_b64)
151
+ return embeddings_b64
152
 
153
  def on_prompt_change_update_embeddings(prompt):
154
+ debug_print("on_prompt_change_update_embeddings")
155
  # prompt to embeddings
156
  if prompt is None or prompt == "":
157
  embeddings = prompt_to_embedding('')
 
159
  return gr.Text.update(embedding_to_base64(embeddings))
160
  embeddings = prompt_to_embedding(prompt)
161
  embeddings_b64 = embedding_to_base64(embeddings)
162
+ return embeddings_b64
163
 
164
  def update_average_embeddings(embedding_base64s_state, embedding_powers):
165
+ debug_print("update_average_embeddings")
166
  final_embedding = None
167
  num_embeddings = 0
168
  for i, embedding_base64 in enumerate(embedding_base64s_state):
 
179
  # embeddings = prompt_to_embedding('')
180
  # embeddings_b64 = embedding_to_base64(embeddings)
181
  # return gr.Text.update(embeddings_b64)
182
+ return ''
183
 
184
  # TODO toggle this to support average or sum
185
  # final_embedding = final_embedding / num_embeddings
 
191
  return embeddings_b64
192
 
193
  def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
194
+ debug_print("on_power_change_update_average_embeddings")
195
  embedding_power_state[idx] = power
196
  embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
197
+ return embeddings_b64
198
 
199
  def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
200
+ debug_print("on_embeddings_changed_update_average_embeddings")
201
  embedding_base64s_state[idx] = embedding_base64 if embedding_base64 != '' else None
202
  embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
203
+ return embeddings_b64
204
 
205
  def on_embeddings_changed_update_plot(embeddings_b64):
206
+ debug_print("on_embeddings_changed_update_plot")
207
  # plot new embeddings
208
  if embeddings_b64 is None or embeddings_b64 == "":
209
  data = pd.DataFrame({
210
  'embedding': [],
211
  'index': []})
212
+ update = gr.LinePlot.update(data,
213
  x="index",
214
  y="embedding",
215
  # color="country",
 
220
  # stroke_dash_legend_title="Country Cluster",
221
  # height=300,
222
  width=0)
223
+ return update
224
 
225
  embeddings = base64_to_embedding(embeddings_b64)
226
  data = pd.DataFrame({
 
239
  width=embeddings.shape[0])
240
 
241
  def on_example_image_click_set_image(input_image, image_url):
242
+ debug_print("on_example_image_click_set_image")
243
  input_image.value = image_url
244
 
245
  # device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
 
266
  # ["SohoJoeEth.jpeg", "Snoop Dogg.jpg", "SohoJoeEth + Snoop Dogg.jpeg"],
267
  ["pup1.jpg", "", "Pup no teacup.jpg"],
268
  ]
269
+ tile_size = 110
270
  # image_folder = os.path.join("file", "images")
271
  image_folder ="images"
272
 
 
379
  # input_image.change(on_image_load, inputs= [input_image, plot])
380
  with gr.Row():
381
  with gr.Column(scale=2, min_width=240):
382
+ input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True, max_lines=4)
383
  with gr.Column(scale=3, min_width=600):
384
  with gr.Row():
385
  # with gr.Slider(min=-5, max=5, value=1, label="Power", show_label=True):
 
387
  embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
388
  with gr.Row():
389
  with gr.Accordion(f"Embeddings (base64)", open=False):
390
+ embedding_base64s[i] = gr.Textbox(show_label=False, live=True)
391
  for idx, (tab_title, examples) in enumerate(tabbed_examples.items()):
392
  with gr.Tab(tab_title):
393
  with gr.Row():
 
425
 
426
  embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
427
  embedding_power_state = gr.State(value=[1. for i in range(max_tabs)])
428
+
429
+ def on_image_load(input_image, idx_state, embedding_base64s_state, embedding_power_state):
430
+ debug_print("on_image_load")
431
+ embeddings_b64 = on_image_load_update_embeddings(input_image)
432
+ new_plot = on_embeddings_changed_update_plot(embeddings_b64)
433
+ average_embeddings_b64 = on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embeddings_b64, idx_state)
434
+ new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
435
+ return embeddings_b64, new_plot, average_embeddings_b64, new_average_plot
436
+
437
+ def on_prompt_change(prompt, idx_state, embedding_base64s_state, embedding_power_state):
438
+ debug_print("on_prompt_change")
439
+ if is_prompt_embeddings(prompt):
440
+ embeddings_b64 = prompt
441
+ else:
442
+ embeddings_b64 = on_prompt_change_update_embeddings(prompt)
443
+ new_plot = on_embeddings_changed_update_plot(embeddings_b64)
444
+ average_embeddings_b64 = on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embeddings_b64, idx_state)
445
+ new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
446
+ return embeddings_b64, new_plot, average_embeddings_b64, new_average_plot
447
+
448
+ def on_power_change(power, idx_state, embedding_base64s_state, embedding_power_state):
449
+ debug_print("on_power_change")
450
+ average_embeddings_b64 = on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx_state)
451
+ new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
452
+ return average_embeddings_b64, new_average_plot
453
+
454
  for i in range(max_tabs):
 
 
 
455
  idx_state = gr.State(value=i)
456
+ input_images[i].change(on_image_load,
457
+ [input_images[i], idx_state, embedding_base64s_state, embedding_power_state],
458
+ [embedding_base64s[i], embedding_plots[i], average_embedding_base64, average_embedding_plot])
459
+ input_prompts[i].change(on_prompt_change,
460
+ [input_prompts[i], idx_state, embedding_base64s_state, embedding_power_state],
461
+ [embedding_base64s[i], embedding_plots[i], average_embedding_base64, average_embedding_plot])
462
+ embedding_powers[i].change(on_power_change,
463
+ [embedding_powers[i], idx_state, embedding_base64s_state, embedding_power_state],
464
+ [average_embedding_base64, average_embedding_plot])
465
+
466
+
467
+ # input_images[i].change(on_image_load_update_embeddings, input_images[i], embedding_base64s[i])
468
+ # input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], embedding_base64s[i])
469
+ # embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], embedding_plots[i])
470
+ # embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_base64s[i], idx_state], average_embedding_base64)
471
+ # embedding_powers[i].change(on_power_change_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_powers[i], idx_state], average_embedding_base64)
472
 
473
+ # average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)
474
 
475
  # submit.click(main, inputs= [embedding_base64s[0], scale, n_samples, steps, seed], outputs=output)
476
  submit.click(main, inputs= [average_embedding_base64, n_samples], outputs=output)
 
506
  # ![Alt Text](file/pup1.jpg){height=100 width=100}
507
 
508
  if __name__ == "__main__":
509
+ demo.launch(debug=True)