sohojoe commited on
Commit
74f54c9
1 Parent(s): a14ceae

loads n closet via api endpoint

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -79,30 +79,26 @@ def main(
79
  seed=None
80
  ):
81
 
82
- if seed == None:
83
- seed = np.random.randint(2147483647)
84
- # if device contains cuda
85
- if device.type == 'cuda':
86
- generator = torch.Generator(device=device).manual_seed(int(seed))
87
- else:
88
- generator = torch.Generator().manual_seed(int(seed)) # use cpu as does not work on mps
89
-
90
  embeddings = base64_to_embedding(embeddings)
91
- embeddings = torch.tensor(embeddings, dtype=torch_size).to(device)
92
-
93
- images_list = pipe(
94
- # inp.tile(n_samples, 1, 1, 1),
95
- # [embeddings * n_samples],
96
- embeddings,
97
- guidance_scale=scale,
98
- num_inference_steps=steps,
99
- generator=generator,
100
- )
101
-
102
  images = []
103
- for i, image in enumerate(images_list["images"]):
104
- images.append(image)
105
- # images.append(embedding_image)
 
 
 
 
 
 
 
 
 
 
 
 
106
  return images
107
 
108
  def on_image_load_update_embeddings(image_data):
@@ -146,10 +142,10 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
146
  return gr.Text.update('')
147
 
148
  # TODO toggle this to support average or sum
149
- final_embedding = final_embedding / num_embeddings
150
 
151
  # normalize embeddings in numpy
152
- final_embedding /= np.linalg.norm(final_embedding)
153
 
154
  embeddings_b64 = embedding_to_base64(final_embedding)
155
  return embeddings_b64
@@ -368,12 +364,12 @@ Try uploading a few images and/or add some text prompts and click generate image
368
  with gr.Accordion(f"Avergage embeddings in base 64", open=False):
369
  average_embedding_base64 = gr.Textbox(show_label=False)
370
  with gr.Row():
371
- submit = gr.Button("Generate images")
372
  with gr.Row():
373
  with gr.Column(scale=1, min_width=200):
374
  scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
375
  with gr.Column(scale=1, min_width=200):
376
- n_samples = gr.Slider(1, 4, value=1, step=1, label="Number images")
377
  with gr.Column(scale=1, min_width=200):
378
  steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
379
  with gr.Column(scale=1, min_width=200):
 
79
  seed=None
80
  ):
81
 
 
 
 
 
 
 
 
 
82
  embeddings = base64_to_embedding(embeddings)
83
+ # convert to python array
84
+ embeddings = embeddings.tolist()
85
+ results = clip_retrieval_client.query(embedding_input=embeddings)
 
 
 
 
 
 
 
 
86
  images = []
87
+ for result in results:
88
+ if len(images) >= n_samples:
89
+ break
90
+ # dowload image
91
+ import requests
92
+ from io import BytesIO
93
+ response = requests.get(result["url"])
94
+ if not response.ok:
95
+ continue
96
+ try:
97
+ bytes = BytesIO(response.content)
98
+ image = Image.open(bytes)
99
+ images.append(image)
100
+ except Exception as e:
101
+ print(e)
102
  return images
103
 
104
  def on_image_load_update_embeddings(image_data):
 
142
  return gr.Text.update('')
143
 
144
  # TODO toggle this to support average or sum
145
+ # final_embedding = final_embedding / num_embeddings
146
 
147
  # normalize embeddings in numpy
148
+ # final_embedding /= np.linalg.norm(final_embedding)
149
 
150
  embeddings_b64 = embedding_to_base64(final_embedding)
151
  return embeddings_b64
 
364
  with gr.Accordion(f"Avergage embeddings in base 64", open=False):
365
  average_embedding_base64 = gr.Textbox(show_label=False)
366
  with gr.Row():
367
+ submit = gr.Button("Search embedding space")
368
  with gr.Row():
369
  with gr.Column(scale=1, min_width=200):
370
  scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
371
  with gr.Column(scale=1, min_width=200):
372
+ n_samples = gr.Slider(1, 16, value=4, step=1, label="Number images")
373
  with gr.Column(scale=1, min_width=200):
374
  steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
375
  with gr.Column(scale=1, min_width=200):