Pedro Cuenca committed on
Commit
cb2ac60
1 Parent(s): 830d7a2

Barebones demo app for local testing.

Former-commit-id: 656c5c2ff48cb8d07d100eee7b04ace061556012

Files changed (2)
  1. app/app.py +176 -0
  2. app/requirements.txt +11 -0
app/app.py ADDED
@@ -0,0 +1,176 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # Uncomment to run on cpu
+ #import os
+ #os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+ import random
+
+ import jax
+ import flax.linen as nn
+ from flax.training.common_utils import shard
+ from flax.jax_utils import replicate, unreplicate
+
+ from transformers.models.bart.modeling_flax_bart import *
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+
+
+ import requests
+ from PIL import Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+
+ import streamlit as st
+
+ st.write("List GPU Device", jax.devices("gpu"))
+ st.write("Loading model...")
+
+ # TODO: set those args in a config file
+ OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
+ OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
+ BOS_TOKEN_ID = 16384
+ BASE_MODEL = 'flax-community/dalle-mini'
+
+ class CustomFlaxBartModule(FlaxBartModule):
+     def setup(self):
+         # we keep shared to easily load pre-trained weights
+         self.shared = nn.Embed(
+             self.config.vocab_size,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         # a separate embedding is used for the decoder
+         self.decoder_embed = nn.Embed(
+             OUTPUT_VOCAB_SIZE,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+         # the decoder has a different config
+         decoder_config = BartConfig(self.config.to_dict())
+         decoder_config.max_position_embeddings = OUTPUT_LENGTH
+         decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+     def setup(self):
+         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
+         self.lm_head = nn.Dense(
+             OUTPUT_VOCAB_SIZE,
+             use_bias=False,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+         )
+         self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
+     module_class = CustomFlaxBartForConditionalGenerationModule
+
+ # create our model
+ # FIXME: Save tokenizer to hub so we can load from there
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
+ model.config.force_bos_token_to_be_generated = False
+ model.config.forced_bos_token_id = None
+ model.config.forced_eos_token_id = None
+
+ vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
+ st.write("VQModel")
+ print("Initialize VqModel")
+
+ def custom_to_pil(x):
+     x = np.clip(x, 0., 1.)
+     x = (255*x).astype(np.uint8)
+     x = Image.fromarray(x)
+     if not x.mode == "RGB":
+         x = x.convert("RGB")
+     return x
+
+ def generate(input, rng, params):
+     return model.generate(
+         **input,
+         max_length=257,
+         num_beams=1,
+         do_sample=True,
+         prng_key=rng,
+         eos_token_id=50000,
+         pad_token_id=50000,
+         params=params,
+     )
+
+ def get_images(indices, params):
+     return vqgan.decode_code(indices, params=params)
+
+ def plot_images(images):
+     fig = plt.figure(figsize=(40, 20))
+     columns = 4
+     rows = 2
+     plt.subplots_adjust(hspace=0, wspace=0)
+
+     for i in range(1, columns*rows +1):
+         fig.add_subplot(rows, columns, i)
+         plt.imshow(images[i-1])
+     plt.gca().axes.get_yaxis().set_visible(False)
+     plt.show()
+
+ def stack_reconstructions(images):
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w, 0))
+     return img
+
+ p_generate = jax.pmap(generate, "batch")
+ p_get_images = jax.pmap(get_images, "batch")
+
+ bart_params = replicate(model.params)
+ vqgan_params = replicate(vqgan.params)
+
+ # ## CLIP Scoring
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ st.write("FlaxCLIPModel")
+ print("Initialize FlaxCLIPModel")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ st.write("CLIPProcessor")
+ print("Initialize CLIPProcessor")
+
+ def hallucinate(prompt, num_images=64):
+     prompt = [prompt] * jax.device_count()
+     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+     inputs = shard(inputs)
+
+     all_images = []
+     for i in range(num_images // jax.device_count()):
+         key = random.randint(0, 1e7)
+         rng = jax.random.PRNGKey(key)
+         rngs = jax.random.split(rng, jax.local_device_count())
+         indices = p_generate(inputs, rngs, bart_params).sequences
+         indices = indices[:, :, 1:]
+
+         images = p_get_images(indices, vqgan_params)
+         images = np.squeeze(np.asarray(images), 1)
+         for image in images:
+             all_images.append(custom_to_pil(image))
+     return all_images
+
+ def clip_top_k(prompt, images, k=8):
+     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+     outputs = clip(**inputs)
+     logits = outputs.logits_per_text
+     scores = np.array(logits[0]).argsort()[-k:][::-1]
+     return [images[score] for score in scores]
+
+ prompt = st.text_input("Input prompt", "rice fields by the mediterranean coast")
+ st.write(f"Generating candidates for: {prompt}")
+
+ images = hallucinate(prompt, num_images=1)
+ st.image(images[0])
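
Note that the demo wired up at the bottom of app.py only calls hallucinate with num_images=1; the CLIP re-ranking helpers (clip_top_k, stack_reconstructions) are defined but never used. As a rough sketch of how those pieces could fit together (not part of this commit; the batch size, k, and output filename below are placeholder choices), something like this could be appended to the script:

prompt = "rice fields by the mediterranean coast"

# num_images is consumed in multiples of jax.device_count(); 8 is an arbitrary choice
candidates = hallucinate(prompt, num_images=8)

# keep the 4 candidates CLIP scores highest for this prompt
best = clip_top_k(prompt, candidates, k=4)

# tile the winners side by side for a quick visual check
stack_reconstructions(best).save("preview.png")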
app/requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # Requirements for huggingface spaces
+ -f https://storage.googleapis.com/jax-releases/jax_releases.html
+ jaxlib
+ -f https://storage.googleapis.com/jax-releases/jax_releases.html
+ jax[cuda111]
+
+ flax
+ transformers
+
+ # To download the model. We could use the model hub.
+ wandb
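
Because the CUDA build of jax comes from a separate wheel index, it is easy to end up with a CPU-only jaxlib without noticing. A quick sanity check before launching the app (a sketch, not part of this commit):

import jax

# should list GPU devices and report "gpu" if the cuda111 wheel resolved correctly
print(jax.devices())
print(jax.default_backend())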