sohojoe commited on
Commit
b7b749f
1 Parent(s): 39f3733

implement power

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +24 -14
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -21,6 +21,7 @@ max_tabs = 10
21
  input_images = [None for i in range(max_tabs)]
22
  input_prompts = [None for i in range(max_tabs)]
23
  embedding_plots = [None for i in range(max_tabs)]
 
24
  # global embedding_base64s
25
  embedding_base64s = [None for i in range(max_tabs)]
26
  # embedding_base64s = gr.State(value=[None for i in range(max_tabs)])
@@ -136,19 +137,14 @@ def on_prompt_change_update_embeddings(prompt):
136
  embeddings_b64 = embedding_to_base64(embeddings)
137
  return gr.Text.update(embeddings_b64)
138
 
139
- # def on_embeddings_changed_update_average_embeddings(last_embedding_base64):
140
- # def on_embeddings_changed_update_average_embeddings(embedding_base64s):
141
- def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_base64, idx):
142
- # global embedding_base64s
143
  final_embedding = None
144
  num_embeddings = 0
145
- embedding_base64s_state[idx] = embedding_base64 if embedding_base64 != '' else None
146
- # for textbox in embedding_base64s:
147
- # embedding_base64 = textbox.value
148
- for embedding_base64 in embedding_base64s_state:
149
  if embedding_base64 is None or embedding_base64 == "":
150
  continue
151
  embedding = base64_to_embedding(embedding_base64)
 
152
  if final_embedding is None:
153
  final_embedding = embedding
154
  else:
@@ -161,6 +157,16 @@ def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, emb
161
  return gr.Text.update('')
162
  final_embedding = final_embedding / num_embeddings
163
  embeddings_b64 = embedding_to_base64(final_embedding)
 
 
 
 
 
 
 
 
 
 
164
  return gr.Text.update(embeddings_b64)
165
 
166
  def on_embeddings_changed_update_plot(embeddings_b64):
@@ -328,8 +334,13 @@ Try uploading a few images and/or add some text prompts and click generate image
328
  with gr.Column(scale=2, min_width=240):
329
  input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True)
330
  with gr.Column(scale=3, min_width=600):
331
- with gr.Accordion(f"Embeddings (base64)", open=False):
332
- embedding_base64s[i] = gr.Textbox(show_label=False)
 
 
 
 
 
333
  for idx, (tab_title, examples) in enumerate(tabbed_examples.items()):
334
  with gr.Tab(tab_title):
335
  with gr.Row():
@@ -372,15 +383,14 @@ Try uploading a few images and/or add some text prompts and click generate image
372
  output = gr.Gallery(label="Generated variations")
373
 
374
  embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
 
375
  for i in range(max_tabs):
376
  input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
377
- # input_prompts[i].submit(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
378
  input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
379
  embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
380
- # embedding_plots[i].change(on_plot_changed, embedding_base64s[i], average_embedding_base64)
381
- # embedding_plots[i].change(on_embeddings_changed_update_average_embeddings, embedding_base64s[i], average_embedding_base64)
382
  idx_state = gr.State(value=i)
383
- embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_base64s[i], idx_state], average_embedding_base64)
 
384
 
385
  average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)
386
 
 
21
  input_images = [None for i in range(max_tabs)]
22
  input_prompts = [None for i in range(max_tabs)]
23
  embedding_plots = [None for i in range(max_tabs)]
24
+ embedding_powers = [1. for i in range(max_tabs)]
25
  # global embedding_base64s
26
  embedding_base64s = [None for i in range(max_tabs)]
27
  # embedding_base64s = gr.State(value=[None for i in range(max_tabs)])
 
137
  embeddings_b64 = embedding_to_base64(embeddings)
138
  return gr.Text.update(embeddings_b64)
139
 
140
+ def update_average_embeddings(embedding_base64s_state, embedding_powers):
 
 
 
141
  final_embedding = None
142
  num_embeddings = 0
143
+ for i, embedding_base64 in enumerate(embedding_base64s_state):
 
 
 
144
  if embedding_base64 is None or embedding_base64 == "":
145
  continue
146
  embedding = base64_to_embedding(embedding_base64)
147
+ embedding = embedding * embedding_powers[i]
148
  if final_embedding is None:
149
  final_embedding = embedding
150
  else:
 
157
  return gr.Text.update('')
158
  final_embedding = final_embedding / num_embeddings
159
  embeddings_b64 = embedding_to_base64(final_embedding)
160
+ return embeddings_b64
161
+
162
+ def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
163
+ embedding_power_state[idx] = power
164
+ embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
165
+ return gr.Text.update(embeddings_b64)
166
+
167
+ def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
168
+ embedding_base64s_state[idx] = embedding_base64 if embedding_base64 != '' else None
169
+ embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
170
  return gr.Text.update(embeddings_b64)
171
 
172
  def on_embeddings_changed_update_plot(embeddings_b64):
 
334
  with gr.Column(scale=2, min_width=240):
335
  input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True)
336
  with gr.Column(scale=3, min_width=600):
337
+ with gr.Row():
338
+ # with gr.Slider(min=-5, max=5, value=1, label="Power", show_label=True):
339
+ # embedding_powers[i] = gr.Slider.value
340
+ embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
341
+ with gr.Row():
342
+ with gr.Accordion(f"Embeddings (base64)", open=False):
343
+ embedding_base64s[i] = gr.Textbox(show_label=False)
344
  for idx, (tab_title, examples) in enumerate(tabbed_examples.items()):
345
  with gr.Tab(tab_title):
346
  with gr.Row():
 
383
  output = gr.Gallery(label="Generated variations")
384
 
385
  embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
386
+ embedding_power_state = gr.State(value=[1. for i in range(max_tabs)])
387
  for i in range(max_tabs):
388
  input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
 
389
  input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
390
  embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
 
 
391
  idx_state = gr.State(value=i)
392
+ embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_base64s[i], idx_state], average_embedding_base64)
393
+ embedding_powers[i].change(on_power_change_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_powers[i], idx_state], average_embedding_base64)
394
 
395
  average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)
396