boris committed
Commit
0ca6514
Parent: d4e833e

refactor: captioned_strip used only in gradio

Files changed (2):
  1. app/gradio/app_gradio.py +67 -28
  2. dalle_mini/helpers.py +0 -14
app/gradio/app_gradio.py CHANGED
@@ -2,21 +2,20 @@
 # coding: utf-8
 
 # Uncomment to run on cpu
-#import os
-#os.environ["JAX_PLATFORM_NAME"] = "cpu"
+# import os
+# os.environ["JAX_PLATFORM_NAME"] = "cpu"
 
 import random
 
 import jax
 import flax.linen as nn
 from flax.training.common_utils import shard
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import replicate
 
-from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+from transformers import BartTokenizer
 
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 import numpy as np
-import matplotlib.pyplot as plt
 
 from vqgan_jax.modeling_flax_vqgan import VQModel
 from dalle_mini.model import CustomFlaxBartForConditionalGeneration
@@ -26,27 +25,47 @@ from transformers import CLIPProcessor, FlaxCLIPModel
 
 import gradio as gr
 
-from dalle_mini.helpers import captioned_strip
+from PIL import Image, ImageDraw, ImageFont
 
 
-DALLE_REPO = 'flax-community/dalle-mini'
-DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
+DALLE_REPO = "flax-community/dalle-mini"
+DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
 
-VQGAN_REPO = 'flax-community/vqgan_f16_16384'
-VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
+VQGAN_REPO = "flax-community/vqgan_f16_16384"
+VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
 
 tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
-model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+model = CustomFlaxBartForConditionalGeneration.from_pretrained(
+    DALLE_REPO, revision=DALLE_COMMIT_ID
+)
 vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
 
+
+def captioned_strip(images, caption=None, rows=1):
+    increased_h = 0 if caption is None else 48
+    w, h = images[0].size[0], images[0].size[1]
+    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
+    for i, img_ in enumerate(images):
+        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
+
+    if caption is not None:
+        draw = ImageDraw.Draw(img)
+        font = ImageFont.truetype(
+            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
+        )
+        draw.text((20, 3), caption, (255, 255, 255), font=font)
+    return img
+
+
 def custom_to_pil(x):
-    x = np.clip(x, 0., 1.)
-    x = (255*x).astype(np.uint8)
+    x = np.clip(x, 0.0, 1.0)
+    x = (255 * x).astype(np.uint8)
     x = Image.fromarray(x)
     if not x.mode == "RGB":
         x = x.convert("RGB")
     return x
 
+
 def generate(input, rng, params):
     return model.generate(
         **input,
@@ -59,9 +78,11 @@ def generate(input, rng, params):
         params=params,
     )
 
+
 def get_images(indices, params):
     return vqgan.decode_code(indices, params=params)
 
+
 p_generate = jax.pmap(generate, "batch")
 p_get_images = jax.pmap(get_images, "batch")
 
@@ -73,9 +94,16 @@ print("Initialize FlaxCLIPModel")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 print("Initialize CLIPProcessor")
 
+
 def hallucinate(prompt, num_images=64):
     prompt = [prompt] * jax.device_count()
-    inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+    inputs = tokenizer(
+        prompt,
+        return_tensors="jax",
+        padding="max_length",
+        truncation=True,
+        max_length=128,
+    ).data
     inputs = shard(inputs)
 
     all_images = []
@@ -92,6 +120,7 @@ def hallucinate(prompt, num_images=64):
         all_images.append(custom_to_pil(image))
     return all_images
 
+
 def clip_top_k(prompt, images, k=8):
     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
     outputs = clip(**inputs)
@@ -99,24 +128,29 @@ def clip_top_k(prompt, images, k=8):
     scores = np.array(logits[0]).argsort()[-k:][::-1]
     return [images[score] for score in scores]
 
+
 def compose_predictions(images, caption=None):
     increased_h = 0 if caption is None else 48
     w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w, h + increased_h))
+    img = Image.new("RGB", (len(images) * w, h + increased_h))
     for i, img_ in enumerate(images):
-        img.paste(img_, (i*w, increased_h))
+        img.paste(img_, (i * w, increased_h))
 
     if caption is not None:
         draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
-        draw.text((20, 3), caption, (255,255,255), font=font)
+        font = ImageFont.truetype(
+            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
+        )
+        draw.text((20, 3), caption, (255, 255, 255), font=font)
     return img
 
+
 def top_k_predictions(prompt, num_candidates=32, k=8):
     images = hallucinate(prompt, num_images=num_candidates)
    images = clip_top_k(prompt, images, k=k)
     return images
 
+
 def run_inference(prompt, num_images=32, num_preds=8):
     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
     predictions = captioned_strip(images)
@@ -125,23 +159,28 @@ def run_inference(prompt, num_images=32, num_preds=8):
     """
     return (output_title, predictions)
 
+
 outputs = [
-    gr.outputs.HTML(label=""), # To be used as title
-    gr.outputs.Image(label=''),
+    gr.outputs.HTML(label=""),  # To be used as title
+    gr.outputs.Image(label=""),
 ]
 
 description = """
 DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
 """
-gr.Interface(run_inference,
-    inputs=[gr.inputs.Textbox(label='What do you want to see?')],
-    outputs=outputs,
-    title='DALL·E mini',
+gr.Interface(
+    run_inference,
+    inputs=[gr.inputs.Textbox(label="What do you want to see?")],
+    outputs=outputs,
+    title="DALL·E mini",
     description=description,
     article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
-    layout='vertical',
-    theme='huggingface',
-    examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
+    layout="vertical",
+    theme="huggingface",
+    examples=[
+        ["an armchair in the shape of an avocado"],
+        ["snowy mountains by the sea"],
+    ],
     allow_flagging=False,
     live=False,
     # server_port=8999
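
For reference, the grid math of the relocated captioned_strip can be checked in isolation. Below is a minimal sketch, not part of the commit: the function is copied verbatim from app_gradio.py above and exercised with hypothetical solid-color placeholder tiles in place of decoded model outputs. Passing a caption additionally assumes the hardcoded LiberationMono font path exists on the host.

# Sketch only: captioned_strip copied from app_gradio.py above, driven by
# hypothetical solid-color stand-ins for decoded model outputs.
from PIL import Image, ImageDraw, ImageFont


def captioned_strip(images, caption=None, rows=1):
    increased_h = 0 if caption is None else 48
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
    for i, img_ in enumerate(images):
        # Tiles are laid out column-major: i // rows picks the column,
        # i % rows picks the row within that column.
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
        )
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img


# Eight 256x256 tiles in two rows -> a 1024x512 canvas (no caption band).
tiles = [Image.new("RGB", (256, 256), (32 * i, 64, 128)) for i in range(8)]
strip = captioned_strip(tiles, caption=None, rows=2)
assert strip.size == (1024, 512)
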
dalle_mini/helpers.py DELETED
@@ -1,14 +0,0 @@
-from PIL import Image, ImageDraw, ImageFont
-
-def captioned_strip(images, caption=None, rows=1):
-    increased_h = 0 if caption is None else 48
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w//rows, h*rows + increased_h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i//rows*w, increased_h + (i % rows) * h))
-
-    if caption is not None:
-        draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
-        draw.text((20, 3), caption, (255,255,255), font=font)
-    return img
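
The import change in app_gradio.py (keeping replicate, dropping unreplicate) touches the data-parallel path that hallucinate relies on: one prompt copy per device, leaf arrays reshaped to carry a per-device leading axis, one result per device. The following is a minimal sketch of that shard/pmap pattern, assuming only that jax and flax are installed; the toy p_fn and the placeholder token batch stand in for the pmapped generate and for real tokenizer output.

# Minimal sketch (assumption: jax and flax installed) of the shard/pmap
# pattern used by hallucinate.
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

n = jax.device_count()

# Stand-in for tokenizer output: one 128-token prompt copy per device.
batch = {"input_ids": jnp.ones((n, 128), dtype=jnp.int32)}
batch = shard(batch)  # each leaf reshaped to (n, 1, 128)

# One PRNG key per device, mirroring the rng argument of generate().
keys = jax.random.split(jax.random.PRNGKey(0), n)

# Toy per-device computation standing in for the pmapped generate.
p_fn = jax.pmap(lambda x: x["input_ids"].sum(), axis_name="batch")
out = p_fn(batch)
print(out.shape)  # (n,): one result per device
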