najoungkim commited on
Commit
87468ed
1 Parent(s): 58cacd2

Initial commit

Browse files
Files changed (4) hide show
  1. LiberationMono-Bold.ttf +0 -0
  2. README.md +4 -4
  3. app.py +204 -0
  4. requirements.txt +13 -0
LiberationMono-Bold.ttf ADDED
Binary file (302 kB). View file
 
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Round Trip Dalle Mini
3
- emoji: 💻
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.0.15
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: Round Trip Dalle Mini
3
+ emoji: 🔁
4
+ colorFrom: pink
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.0.14
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import os
5
+ # Uncomment to run on cpu
6
+ # os.environ["JAX_PLATFORM_NAME"] = "cpu"
7
+ os.environ["WANDB_DISABLED"] = "true"
8
+ os.environ['WANDB_SILENT']="true"
9
+
10
+ import random
11
+ import re
12
+ import torch
13
+
14
+ import gradio as gr
15
+ import jax
16
+ import jax.numpy as jnp
17
+ import numpy as np
18
+ from flax.jax_utils import replicate
19
+ from flax.training.common_utils import shard, shard_prng_key
20
+ from PIL import Image, ImageDraw, ImageFont
21
+
22
+ from functools import partial
23
+
24
+ from transformers import CLIPProcessor, FlaxCLIPModel, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
25
+ from dalle_mini import DalleBart, DalleBartProcessor
26
+ from vqgan_jax.modeling_flax_vqgan import VQModel
27
+
28
+
29
+ DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
30
+ DALLE_COMMIT_ID = None
31
+
32
+ VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
33
+ VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
34
+
35
+ model, params = DalleBart.from_pretrained(
36
+ DALLE_REPO, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
37
+ )
38
+ vqgan, vqgan_params = VQModel.from_pretrained(
39
+ VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
40
+ )
41
+
42
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
43
+
44
+ encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
45
+ decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
46
+ model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
47
+ feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
48
+ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
49
+ viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
50
+
51
+
52
+ def captioned_strip(images, caption=None, rows=1):
53
+ increased_h = 0 if caption is None else 24
54
+ w, h = images[0].size[0], images[0].size[1]
55
+ img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
56
+ for i, img_ in enumerate(images):
57
+ img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
58
+
59
+ if caption is not None:
60
+ draw = ImageDraw.Draw(img)
61
+ font = ImageFont.truetype(
62
+ "LiberationMono-Bold.ttf", 7
63
+ )
64
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
65
+ return img
66
+
67
+
68
+ def get_images(indices, params):
69
+ return vqgan.decode_code(indices, params=params)
70
+
71
+
72
+ def predict_caption(image, max_length=128, num_beams=4):
73
+ image = image.convert('RGB')
74
+ image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
75
+ clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
76
+ caption_ids = viz_model.generate(image, max_length = max_length)[0]
77
+ caption_text = clean_text(tokenizer.decode(caption_ids))
78
+ return caption_text
79
+
80
+
81
+ # model inference
82
+ @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
83
+ def p_generate(
84
+ tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
85
+ ):
86
+ return model.generate(
87
+ **tokenized_prompt,
88
+ prng_key=key,
89
+ params=params,
90
+ top_k=top_k,
91
+ top_p=top_p,
92
+ temperature=temperature,
93
+ condition_scale=condition_scale,
94
+ )
95
+
96
+
97
+ # decode image
98
+ @partial(jax.pmap, axis_name="batch")
99
+ def p_decode(indices, params):
100
+ return vqgan.decode_code(indices, params=params)
101
+
102
+ p_get_images = jax.pmap(get_images, "batch")
103
+
104
+ params = replicate(params)
105
+ vqgan_params = replicate(vqgan_params)
106
+
107
+ processor = DalleBartProcessor.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
108
+ print("Initialized DalleBartProcessor")
109
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
110
+ print("Initialized FlaxCLIPModel")
111
+
112
+
113
+ def hallucinate(prompt, num_images=8):
114
+ gen_top_k = None
115
+ gen_top_p = None
116
+ temperature = None
117
+ cond_scale = 10.0
118
+
119
+ print(f"Prompts: {prompt}")
120
+ prompt = [prompt] * jax.device_count()
121
+ inputs = processor(prompt)
122
+ inputs = replicate(inputs)
123
+
124
+ # create a random key
125
+ seed = random.randint(0, 2**32 - 1)
126
+ key = jax.random.PRNGKey(seed)
127
+
128
+ images = []
129
+ for i in range(max(num_images // jax.device_count(), 1)):
130
+ key, subkey = jax.random.split(key)
131
+ encoded_images = p_generate(
132
+ inputs,
133
+ shard_prng_key(subkey),
134
+ params,
135
+ gen_top_k,
136
+ gen_top_p,
137
+ temperature,
138
+ cond_scale,
139
+ )
140
+ print(f"Encoded image {i}")
141
+ # remove BOS
142
+ encoded_images = encoded_images.sequences[..., 1:]
143
+ # decode images
144
+ decoded_images = p_decode(encoded_images, vqgan_params)
145
+ decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
146
+ for decoded_img in decoded_images:
147
+ img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
148
+ images.append(img)
149
+
150
+ print(f"Finished decoding image {i}")
151
+ return images
152
+
153
+
154
+ def run_inference(prompt, num_roundtrips=3, num_images=1):
155
+ outputs = []
156
+ for i in range(int(num_roundtrips)):
157
+ images = hallucinate(prompt, num_images=num_images)
158
+ image = images[0]
159
+ print("Generated image")
160
+ caption = predict_caption(image)
161
+ print(f"Predicted caption: {caption}")
162
+
163
+ output_title = f"""
164
+ <font size="+3">
165
+ <b>[Roundtrip {i}]</b><br>
166
+ Prompt: {prompt}<br>
167
+ 🥑 :<br></font>"""
168
+ output_caption = f"""
169
+ <font size="+3">
170
+ 🤖💬 : {caption}<br>
171
+ </font>
172
+ """
173
+ outputs.append(output_title)
174
+ outputs.append(image)
175
+ outputs.append(output_caption)
176
+ prompt = caption
177
+
178
+ return outputs
179
+
180
+
181
+ inputs = gr.inputs.Textbox(label="What prompt do you want to start with?", default="a poster of cookie monster live action")
182
+ # num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
183
+ num_roundtrips = 3
184
+ outputs = []
185
+ for _ in range(int(num_roundtrips)):
186
+ outputs.append(gr.outputs.HTML(label=""))
187
+ outputs.append(gr.Image(label=""))
188
+ outputs.append(gr.outputs.HTML(label=""))
189
+
190
+ description = """
191
+ Round trip DALL·E-mini iterates between DALL·E generation and image captioning, inspired by round trip translation!
192
+ """
193
+ article = "<p style='text-align: center'>Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"
194
+
195
+ gr.Interface(
196
+ fn=run_inference,
197
+ inputs=[inputs],
198
+ outputs=outputs,
199
+ title="Round Trip DALL·E mini 🥑🔁🤖💬",
200
+ description=description,
201
+ article=article,
202
+ theme="default",
203
+ css = ".output-image, .input-image, .image-preview {height: 256px !important} "
204
+ ).launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=2.2.3
2
+ flax
3
+ transformers
4
+ einops
5
+ unidecode
6
+ ftfy
7
+ emoji
8
+ pillow
9
+ jax
10
+ flax
11
+ torch
12
+ git+https://github.com/patil-suraj/vqgan-jax.git
13
+ git+https://github.com/borisdayma/dalle-mini.git