boris committed
Commit 616cf8e
1 Parent(s): badd15c

feat(app): migrate demo to Gradio by @AK391 (#179)

README.md CHANGED
@@ -3,8 +3,9 @@ title: DALL·E mini
  emoji: 🥑
  colorFrom: yellow
  colorTo: green
- sdk: streamlit
- app_file: app/streamlit/app.py
+ sdk: gradio
+ sdk_version: 3.0b6
+ app_file: app/gradio/app.py
  pinned: True
  license: apache-2.0
  ---
app/gradio/app.py ADDED
@@ -0,0 +1,53 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+ import os
+
+ import gradio as gr
+ from backend import get_images_from_backend
+
+ block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
+ backend_url = os.environ["BACKEND_SERVER"] + "/generate"
+
+
+ def infer(prompt):
+     response = get_images_from_backend(prompt, backend_url)
+     return response["images"]
+
+
+ with block:
+     gr.Markdown("<h1><center>DALL·E mini</center></h1>")
+     gr.Markdown(
+         "DALL·E mini is an AI model that generates images from any prompt you give!"
+     )
+     with gr.Group():
+         with gr.Box():
+             with gr.Row().style(mobile_collapse=False, equal_height=True):
+                 text = gr.Textbox(
+                     label="Enter your prompt", show_label=False, max_lines=1
+                 ).style(
+                     border=(True, False, True, True),
+                     margin=False,
+                     rounded=(True, False, False, True),
+                     container=False,
+                 )
+                 btn = gr.Button("Run").style(
+                     margin=False,
+                     rounded=(False, True, True, False),
+                 )
+         gallery = gr.Gallery(label="Generated images", show_label=False).style(
+             grid=[3], height="auto"
+         )
+     btn.click(infer, inputs=text, outputs=gallery)
+
+     gr.Markdown(
+         """___
+ <p style='text-align: center'>
+ Created by Boris Dayma et al. 2021-2022
+ <br/>
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
+ </p>"""
+     )
+
+
+ block.launch()
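The new app is a thin Gradio Blocks front end: infer() POSTs the prompt to the backend and drops the returned PIL images into a gr.Gallery. Below is a minimal smoke-test sketch, not part of this commit, that swaps in a hypothetical fake_infer() stub so the same Blocks wiring can be exercised locally without a running BACKEND_SERVER:

# smoke_test_app.py -- hypothetical local test; not part of this commit.
# Replaces the backend call with a stub so the UI can run standalone.
import gradio as gr
from PIL import Image


def fake_infer(prompt):
    # Nine solid-color placeholders standing in for the backend's images.
    return [Image.new("RGB", (256, 256), (25 * i, 120, 160)) for i in range(9)]


with gr.Blocks() as demo:
    text = gr.Textbox(label="Enter your prompt")
    btn = gr.Button("Run")
    gallery = gr.Gallery(label="Generated images")
    btn.click(fake_infer, inputs=text, outputs=gallery)

demo.launch()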
app/gradio/app_gradio.py DELETED
@@ -1,179 +0,0 @@
- #!/usr/bin/env python
- # coding: utf-8
-
- # Uncomment to run on cpu
- # import os
- # os.environ["JAX_PLATFORM_NAME"] = "cpu"
-
- import random
-
- import gradio as gr
- import jax
- import numpy as np
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
- from PIL import Image, ImageDraw, ImageFont
-
- # ## CLIP Scoring
- from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
- from vqgan_jax.modeling_flax_vqgan import VQModel
-
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
-
- DALLE_REPO = "flax-community/dalle-mini"
- DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
-
- VQGAN_REPO = "flax-community/vqgan_f16_16384"
- VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
-
- tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(
-     DALLE_REPO, revision=DALLE_COMMIT_ID
- )
- vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
-
-
- def captioned_strip(images, caption=None, rows=1):
-     increased_h = 0 if caption is None else 48
-     w, h = images[0].size[0], images[0].size[1]
-     img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
-     for i, img_ in enumerate(images):
-         img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
-
-     if caption is not None:
-         draw = ImageDraw.Draw(img)
-         font = ImageFont.truetype(
-             "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
-         )
-         draw.text((20, 3), caption, (255, 255, 255), font=font)
-     return img
-
-
- def custom_to_pil(x):
-     x = np.clip(x, 0.0, 1.0)
-     x = (255 * x).astype(np.uint8)
-     x = Image.fromarray(x)
-     if not x.mode == "RGB":
-         x = x.convert("RGB")
-     return x
-
-
- def generate(input, rng, params):
-     return model.generate(
-         **input,
-         max_length=257,
-         num_beams=1,
-         do_sample=True,
-         prng_key=rng,
-         eos_token_id=50000,
-         pad_token_id=50000,
-         params=params,
-     )
-
-
- def get_images(indices, params):
-     return vqgan.decode_code(indices, params=params)
-
-
- p_generate = jax.pmap(generate, "batch")
- p_get_images = jax.pmap(get_images, "batch")
-
- bart_params = replicate(model.params)
- vqgan_params = replicate(vqgan.params)
-
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- print("Initialize FlaxCLIPModel")
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
- print("Initialize CLIPProcessor")
-
-
- def hallucinate(prompt, num_images=64):
-     prompt = [prompt] * jax.device_count()
-     inputs = tokenizer(
-         prompt,
-         return_tensors="jax",
-         padding="max_length",
-         truncation=True,
-         max_length=128,
-     ).data
-     inputs = shard(inputs)
-
-     all_images = []
-     for i in range(num_images // jax.device_count()):
-         key = random.randint(0, 1e7)
-         rng = jax.random.PRNGKey(key)
-         rngs = jax.random.split(rng, jax.local_device_count())
-         indices = p_generate(inputs, rngs, bart_params).sequences
-         indices = indices[:, :, 1:]
-
-         images = p_get_images(indices, vqgan_params)
-         images = np.squeeze(np.asarray(images), 1)
-         for image in images:
-             all_images.append(custom_to_pil(image))
-     return all_images
-
-
- def clip_top_k(prompt, images, k=8):
-     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-     outputs = clip(**inputs)
-     logits = outputs.logits_per_text
-     scores = np.array(logits[0]).argsort()[-k:][::-1]
-     return [images[score] for score in scores]
-
-
- def compose_predictions(images, caption=None):
-     increased_h = 0 if caption is None else 48
-     w, h = images[0].size[0], images[0].size[1]
-     img = Image.new("RGB", (len(images) * w, h + increased_h))
-     for i, img_ in enumerate(images):
-         img.paste(img_, (i * w, increased_h))
-
-     if caption is not None:
-         draw = ImageDraw.Draw(img)
-         font = ImageFont.truetype(
-             "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
-         )
-         draw.text((20, 3), caption, (255, 255, 255), font=font)
-     return img
-
-
- def top_k_predictions(prompt, num_candidates=32, k=8):
-     images = hallucinate(prompt, num_images=num_candidates)
-     images = clip_top_k(prompt, images, k=k)
-     return images
-
-
- def run_inference(prompt, num_images=32, num_preds=8):
-     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
-     predictions = captioned_strip(images)
-     output_title = f"""
-     <b>{prompt}</b>
-     """
-     return (output_title, predictions)
-
-
- outputs = [
-     gr.outputs.HTML(label=""),  # To be used as title
-     gr.outputs.Image(label=""),
- ]
-
- description = """
- DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
- """
- gr.Interface(
-     run_inference,
-     inputs=[gr.inputs.Textbox(label="What do you want to see?")],
-     outputs=outputs,
-     title="DALL·E mini",
-     description=description,
-     article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
-     layout="vertical",
-     theme="huggingface",
-     examples=[
-         ["an armchair in the shape of an avocado"],
-         ["snowy mountains by the sea"],
-     ],
-     allow_flagging=False,
-     live=False,
-     # server_port=8999
- ).launch(share=True)
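The deleted script ran the whole pipeline in-process: tokenize the prompt, pmap the BART generation and VQGAN decoding across devices, then rerank candidates with CLIP. For reference, a minimal sketch of the shard/pmap data-parallel pattern it relied on; step() here is an illustrative stand-in, not the real generate call:

# Sketch of the flax shard + jax.pmap pattern from the removed script.
import jax
import numpy as np
from flax.training.common_utils import shard


def step(x):
    # Stand-in for the per-device p_generate / p_get_images work.
    return x * 2


p_step = jax.pmap(step, "batch")

n = jax.local_device_count()
batch = np.arange(n * 2 * 4, dtype=np.float32).reshape(n * 2, 4)
sharded = shard(batch)   # adds a leading device axis: (n, 2, 4)
out = p_step(sharded)    # each device processes its own slice
print(out.shape)         # (n, 2, 4)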
app/gradio/backend.py ADDED
@@ -0,0 +1,33 @@
+ # Client requests to Dalle-Mini Backend server
+
+ import base64
+ from io import BytesIO
+
+ import requests
+ from PIL import Image
+
+
+ class ServiceError(Exception):
+     def __init__(self, status_code):
+         self.status_code = status_code
+
+
+ def get_images_from_backend(prompt, backend_url):
+     r = requests.post(backend_url, json={"prompt": prompt})
+     if r.status_code == 200:
+         json = r.json()
+         images = json["images"]
+         images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
+         version = json.get("version", "unknown")
+         return {"images": images, "version": version}
+     else:
+         raise ServiceError(r.status_code)
+
+
+ def get_model_version(url):
+     r = requests.get(url)
+     if r.status_code == 200:
+         version = r.json()["version"]
+         return version
+     else:
+         raise ServiceError(r.status_code)
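The backend protocol is simple: POST {"prompt": ...} as JSON to the generate endpoint and receive base64-encoded images plus a model version, with non-200 responses surfaced as ServiceError. A hypothetical usage sketch; the localhost URL is an assumption, not from this commit:

# Hypothetical client usage of app/gradio/backend.py.
from backend import ServiceError, get_images_from_backend

try:
    result = get_images_from_backend(
        "an armchair in the shape of an avocado",
        "http://localhost:8000/generate",  # assumed URL for illustration
    )
    for i, image in enumerate(result["images"]):
        image.save(f"generated_{i}.png")   # decoded PIL images
    print("model version:", result["version"])
except ServiceError as err:
    print(f"backend returned HTTP {err.status_code}")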
app/gradio/requirements.txt DELETED
@@ -1,4 +0,0 @@
- # Requirements for huggingface spaces
- gradio>=2.2.3
- flax
- transformers
app/streamlit/app.py CHANGED
@@ -1,8 +1,6 @@
  #!/usr/bin/env python
  # coding: utf-8

- from datetime import datetime
-
  import streamlit as st
  from backend import ServiceError, get_images_from_backend