Pedro Cuenca commited on
Commit
85eab14
·
1 Parent(s): 8944cc5

Integrate current UI in demo app.

Browse files

Note that the port number has been removed. I suppose hf's Spaces will
forward requests to the default service port.


Former-commit-id: acb4488ee9887246a28a5c2358bafbda0e29355d

Files changed (2) hide show
  1. app/app_gradio.py +49 -13
  2. app/ui_gradio.py +2 -2
app/app_gradio.py CHANGED
@@ -163,7 +163,7 @@ def clip_top_k(prompt, images, k=8):
163
  scores = np.array(logits[0]).argsort()[-k:][::-1]
164
  return [images[score] for score in scores]
165
 
166
- def captioned_strip(images, caption):
167
  increased_h = 0 if caption is None else 48
168
  w, h = images[0].size[0], images[0].size[1]
169
  img = Image.new("RGB", (len(images)*w, h + increased_h))
@@ -176,19 +176,55 @@ def captioned_strip(images, caption):
176
  draw.text((20, 3), caption, (255,255,255), font=font)
177
  return img
178
 
179
- def run_inference(prompt, num_images=64, num_preds=8):
180
- images = hallucinate(prompt, num_images=num_images)
181
- images = clip_top_k(prompt, images, k=num_preds)
182
- predictions_strip = captioned_strip(images, None)
183
- return predictions_strip
184
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  gr.Interface(run_inference,
186
  inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
187
- outputs=gr.outputs.Image(label='Generated image'),
188
- title='DALLE-mini - HuggingFace Community Week',
189
- description='This is a demo of the DALLE-mini model trained with Jax/Flax on TPU v3-8s during the HuggingFace Community Week',
190
- article="<p style='text-align: center'> DALLE-mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
191
  layout='vertical',
192
  theme='huggingface',
193
- examples=[['an armchair in the shape of an avocado']],
194
- server_port=8999).launch(share=True)
 
 
 
 
163
  scores = np.array(logits[0]).argsort()[-k:][::-1]
164
  return [images[score] for score in scores]
165
 
166
+ def compose_predictions(images, caption=None):
167
  increased_h = 0 if caption is None else 48
168
  w, h = images[0].size[0], images[0].size[1]
169
  img = Image.new("RGB", (len(images)*w, h + increased_h))
 
176
  draw.text((20, 3), caption, (255,255,255), font=font)
177
  return img
178
 
179
+ def top_k_predictions(prompt, num_candidates=32, k=8):
180
+ images = hallucinate(prompt, num_images=num_candidates)
181
+ images = clip_top_k(prompt, images, k=k)
182
+ return images
183
+
184
+ def run_inference(prompt, num_images=32, num_preds=8):
185
+ images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
186
+ predictions = compose_predictions(images)
187
+ output_title = f"""
188
+ <p style="font-size:22px; font-style:bold">Best predictions</p>
189
+ <p>We asked our model to generate 32 candidates for your prompt:</p>
190
+
191
+ <pre>
192
+
193
+ <b>{prompt}</b>
194
+ </pre>
195
+ <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
196
+ similarity of the text and the image representations.</p>
197
+
198
+ <p>This is the result:</p>
199
+ """
200
+ output_description = """
201
+ <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
202
+ <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
203
+ """
204
+ return (output_title, predictions, output_description)
205
+
206
+ outputs = [
207
+ gr.outputs.HTML(label=""), # To be used as title
208
+ gr.outputs.Image(label=''),
209
+ gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
210
+ ]
211
+
212
+ description = """
213
+ Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
214
+ It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
215
+
216
+ Please, write what you would like the model to generate, or select one of the examples below.
217
+ """
218
  gr.Interface(run_inference,
219
  inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
220
+ outputs=outputs,
221
+ title='DALL·E mini',
222
+ description=description,
223
+ article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
224
  layout='vertical',
225
  theme='huggingface',
226
+ examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
227
+ allow_flagging=False,
228
+ live=False,
229
+ # server_port=8999
230
+ ).launch()
app/ui_gradio.py CHANGED
@@ -51,8 +51,8 @@ def run_inference(prompt, num_images=32, num_preds=8):
51
 
52
  <b>{prompt}</b>
53
  </pre>
54
- <p>We then used a pre-trained CLIP model to score them according to the
55
- similarity of their text and image representations.</p>
56
 
57
  <p>This is the result:</p>
58
  """
 
51
 
52
  <b>{prompt}</b>
53
  </pre>
54
+ <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
55
+ similarity of the text and the image representations.</p>
56
 
57
  <p>This is the result:</p>
58
  """