johnpaulbin committed
Commit 3c0bb9c (parent: c5d3d6b)

Create app.py

Files changed (1): app.py (+119, -0)
app.py ADDED
@@ -0,0 +1,119 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+ # Uncomment to run on CPU
+ #import os
+ #os.environ["JAX_PLATFORM_NAME"] = "cpu"
+ import random
+ import jax
+ import flax.linen as nn
+ from flax.training.common_utils import shard
+ from flax.jax_utils import replicate, unreplicate
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+ from PIL import Image, ImageDraw, ImageFont  # ImageDraw/ImageFont are used by compose_predictions below
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from vqgan_jax.modeling_flax_vqgan import VQModel
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
+ # ## CLIP Scoring
+ from transformers import CLIPProcessor, FlaxCLIPModel
+ import gradio as gr
+ from dalle_mini.helpers import captioned_strip
+ DALLE_REPO = 'flax-community/dalle-mini'
+ DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
+ VQGAN_REPO = 'flax-community/vqgan_f16_16384'
+ VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
+ tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+ vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
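+ # Pipeline used below: the BART model generates a sequence of VQGAN codebook
+ # indices from the text prompt, the VQGAN decoder turns those indices into an
+ # image, and CLIP ranks the candidate images against the prompt.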
+ def custom_to_pil(x):
+     x = np.clip(x, 0., 1.)
+     x = (255*x).astype(np.uint8)
+     x = Image.fromarray(x)
+     if not x.mode == "RGB":
+         x = x.convert("RGB")
+     return x
+ def generate(input, rng, params):
+     return model.generate(
+         **input,
+         max_length=257,
+         num_beams=1,
+         do_sample=True,
+         prng_key=rng,
+         eos_token_id=50000,
+         pad_token_id=50000,
+         params=params,
+     )
+ def get_images(indices, params):
+     return vqgan.decode_code(indices, params=params)
+ p_generate = jax.pmap(generate, "batch")
+ p_get_images = jax.pmap(get_images, "batch")
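+ # jax.pmap runs generate/get_images in parallel, one replica per device,
+ # mapping over the leading axis of the inputs (axis name "batch"); the model
+ # parameters are replicated across devices just below.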
+ bart_params = replicate(model.params)
+ vqgan_params = replicate(vqgan.params)
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialized FlaxCLIPModel")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialized CLIPProcessor")
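+ # hallucinate() broadcasts the prompt to every device, shards the tokenized
+ # batch, and loops until num_images candidate images have been decoded to PIL.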
+ def hallucinate(prompt, num_images=64):
+     prompt = [prompt] * jax.device_count()
+     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+     inputs = shard(inputs)
+     all_images = []
+     for i in range(num_images // jax.device_count()):
+         key = random.randint(0, 10_000_000)  # integer bound; 1e7 is a float, which newer Pythons reject
+         rng = jax.random.PRNGKey(key)
+         rngs = jax.random.split(rng, jax.local_device_count())
+         indices = p_generate(inputs, rngs, bart_params).sequences
+         indices = indices[:, :, 1:]
+         images = p_get_images(indices, vqgan_params)
+         images = np.squeeze(np.asarray(images), 1)
+         for image in images:
+             all_images.append(custom_to_pil(image))
+     return all_images
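+ # CLIP reranking: logits_per_text holds one image-text similarity score per
+ # candidate image, so argsort keeps the k highest-scoring candidates.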
+ def clip_top_k(prompt, images, k=8):
+     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+     outputs = clip(**inputs)
+     logits = outputs.logits_per_text
+     scores = np.array(logits[0]).argsort()[-k:][::-1]
+     return [images[score] for score in scores]
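+ # compose_predictions pastes the images into a horizontal strip and optionally
+ # draws a caption; the Gradio app below uses captioned_strip from
+ # dalle_mini.helpers instead, so this helper is not wired into the UI.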
+ def compose_predictions(images, caption=None):
+     increased_h = 0 if caption is None else 48
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h + increased_h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w, increased_h))
+     if caption is not None:
+         draw = ImageDraw.Draw(img)
+         font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
+         draw.text((20, 3), caption, (255,255,255), font=font)
+     return img
+ def top_k_predictions(prompt, num_candidates=32, k=8):
+     images = hallucinate(prompt, num_images=num_candidates)
+     images = clip_top_k(prompt, images, k=k)
+     return images
+ def run_inference(prompt, num_images=32, num_preds=8):
+     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
+     predictions = captioned_strip(images)
+     output_title = f"""
+     <b>{prompt}</b>
+     """
+     return (output_title, predictions)
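+ # Note: gr.inputs / gr.outputs below is the older Gradio interface API that
+ # this script appears to target; recent Gradio releases expose components such
+ # as gr.Textbox and gr.Image directly instead.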
+ outputs = [
+     gr.outputs.HTML(label=""),   # To be used as title
+     gr.outputs.Image(label=''),
+ ]
+ description = """
+ DALL·E mini is an AI model that generates images from any prompt you give it! Generate images from text:
+ """
+ gr.Interface(run_inference,
+     inputs=[gr.inputs.Textbox(label='What do you want to see?')],
+     outputs=outputs,
+     title='DALL·E mini',
+     description=description,
+     article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
+     layout='vertical',
+     theme='huggingface',
+     examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
+     allow_flagging=False,
+     live=False,
+     # server_port=8999
+ ).launch(share=True)