Pedro Cuenca committed
Commit ffed138
1 Parent(s): f62b045

Simple skeleton for a Streamlit app


In order to use it, you need to create a file `.streamlit/secrets.toml`
to define the URL of the BACKEND_SERVER:

```
BACKEND_SERVER="<server url>"
```
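At runtime the app reads this value through Streamlit's secrets API, which parses `.streamlit/secrets.toml` automatically. A minimal sketch of the lookup (an unset key raises `KeyError`, which the app catches and turns into setup instructions):

```
import streamlit as st

# Streamlit parses .streamlit/secrets.toml and exposes it as st.secrets;
# a missing BACKEND_SERVER key raises KeyError.
backend_url = st.secrets["BACKEND_SERVER"]
st.write(f"Using backend at {backend_url}")
```

Once the secret is in place, the app can be launched with `streamlit run app/app.py`.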


Former-commit-id: 4d81cb1c805c903c74b82a5706b3a54ce8a2348b

Files changed (1)
  1. app/app.py +22 -180
app/app.py CHANGED
@@ -1,196 +1,38 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# Uncomment to run on cpu
-#import os
-#os.environ["JAX_PLATFORM_NAME"] = "cpu"
-
 import random
-
-import jax
-import flax.linen as nn
-from flax.training.common_utils import shard
-from flax.jax_utils import replicate, unreplicate
-
-from transformers.models.bart.modeling_flax_bart import *
-from transformers import BartTokenizer, FlaxBartForConditionalGeneration
-
-
-import requests
-from PIL import Image
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from dalle_mini.backend import ServiceError, get_images_from_backend
+from dalle_mini.helpers import captioned_strip
 
 import streamlit as st
 
-st.write("Loading model...")
-
-# TODO: set those args in a config file
-OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'flax-community/dalle-mini'
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
-
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
-    def setup(self):
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
-# create our model
-# FIXME: Save tokenizer to hub so we can load from there
-tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
-model.config.force_bos_token_to_be_generated = False
-model.config.forced_bos_token_id = None
-model.config.forced_eos_token_id = None
-
-vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
-
-def custom_to_pil(x):
-    x = np.clip(x, 0., 1.)
-    x = (255*x).astype(np.uint8)
-    x = Image.fromarray(x)
-    if not x.mode == "RGB":
-        x = x.convert("RGB")
-    return x
-
-def generate(input, rng, params):
-    return model.generate(
-        **input,
-        max_length=257,
-        num_beams=1,
-        do_sample=True,
-        prng_key=rng,
-        eos_token_id=50000,
-        pad_token_id=50000,
-        params=params,
-    )
-
-def get_images(indices, params):
-    return vqgan.decode_code(indices, params=params)
-
-def plot_images(images):
-    fig = plt.figure(figsize=(40, 20))
-    columns = 4
-    rows = 2
-    plt.subplots_adjust(hspace=0, wspace=0)
-
-    for i in range(1, columns*rows +1):
-        fig.add_subplot(rows, columns, i)
-        plt.imshow(images[i-1])
-        plt.gca().axes.get_yaxis().set_visible(False)
-    plt.show()
-
-def stack_reconstructions(images):
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w, h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i*w, 0))
-    return img
-
-p_generate = jax.pmap(generate, "batch")
-p_get_images = jax.pmap(get_images, "batch")
-
-bart_params = replicate(model.params)
-vqgan_params = replicate(vqgan.params)
-
-# ## CLIP Scoring
-from transformers import CLIPProcessor, FlaxCLIPModel
-
-clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-# st.write("FlaxCLIPModel")
-# print("Initialize FlaxCLIPModel")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-# st.write("CLIPProcessor")
-# print("Initialize CLIPProcessor")
-
-def hallucinate(prompt, num_images=64):
-    prompt = [prompt] * jax.device_count()
-    inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
-    inputs = shard(inputs)
-
-    all_images = []
-    for i in range(num_images // jax.device_count()):
-        key = random.randint(0, 1e7)
-        rng = jax.random.PRNGKey(key)
-        rngs = jax.random.split(rng, jax.local_device_count())
-        indices = p_generate(inputs, rngs, bart_params).sequences
-        indices = indices[:, :, 1:]
-
-        images = p_get_images(indices, vqgan_params)
-        images = np.squeeze(np.asarray(images), 1)
-        for image in images:
-            all_images.append(custom_to_pil(image))
-    return all_images
-
-def clip_top_k(prompt, images, k=8):
-    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-    outputs = clip(**inputs)
-    logits = outputs.logits_per_text
-    scores = np.array(logits[0]).argsort()[-k:][::-1]
-    return [images[score] for score in scores]
-
-def captioned_strip(images, caption):
-    increased_h = 0 if caption is None else 48
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images)*w, h + increased_h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i*w, increased_h))
-
-    if caption is not None:
-        draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
-        draw.text((20, 3), caption, (255,255,255), font=font)
-    return img
-
 # Controls
 
-num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
-num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
+# num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
+# num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
+
+st.sidebar.markdown('Visit [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)')
 
 prompt = st.text_input("What do you want to see?")
 
 if prompt != "":
     st.write(f"Generating candidates for: {prompt}")
-    images = hallucinate(prompt, num_images=num_images)
-    images = clip_top_k(prompt, images, k=num_preds)
-    predictions_strip = captioned_strip(images, None)
 
-    st.image(predictions_strip)
+    try:
+        backend_url = st.secrets["BACKEND_SERVER"]
+        print(f"Getting selections: {prompt}")
+        selected = get_images_from_backend(prompt, backend_url)
+        preds = captioned_strip(selected, prompt)
+        st.image(preds)
+    except ServiceError as error:
+        st.write(f"Service unavailable, status: {error.status_code}")
+    except KeyError:
+        st.write("""
+**Error: BACKEND_SERVER unset**
+
+Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
+```
+BACKEND_SERVER="<server url>"
+```
+""")