dalle-mini / app /app.py
Pedro Cuenca
Add a couple of sliders and prevent generating without a prompt
0e8338d
raw
history blame
6.59 kB
#!/usr/bin/env python
# coding: utf-8
# Uncomment to run on cpu
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import random
import jax
import flax.linen as nn
from flax.training.common_utils import shard
from flax.jax_utils import replicate, unreplicate
from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration
import requests
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
import streamlit as st
st.write("Loading model...")
# TODO: set those args in a config file
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
BOS_TOKEN_ID = 16384
BASE_MODEL = 'flax-community/dalle-mini'
class CustomFlaxBartModule(FlaxBartModule):
def setup(self):
# we keep shared to easily load pre-trained weights
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
# a separate embedding is used for the decoder
self.decoder_embed = nn.Embed(
OUTPUT_VOCAB_SIZE,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
# the decoder has a different config
decoder_config = BartConfig(self.config.to_dict())
decoder_config.max_position_embeddings = OUTPUT_LENGTH
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
def setup(self):
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
OUTPUT_VOCAB_SIZE,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
module_class = CustomFlaxBartForConditionalGenerationModule
# create our model
# FIXME: Save tokenizer to hub so we can load from there
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
model.config.force_bos_token_to_be_generated = False
model.config.forced_bos_token_id = None
model.config.forced_eos_token_id = None
vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
def custom_to_pil(x):
x = np.clip(x, 0., 1.)
x = (255*x).astype(np.uint8)
x = Image.fromarray(x)
if not x.mode == "RGB":
x = x.convert("RGB")
return x
def generate(input, rng, params):
return model.generate(
**input,
max_length=257,
num_beams=1,
do_sample=True,
prng_key=rng,
eos_token_id=50000,
pad_token_id=50000,
params=params,
)
def get_images(indices, params):
return vqgan.decode_code(indices, params=params)
def plot_images(images):
fig = plt.figure(figsize=(40, 20))
columns = 4
rows = 2
plt.subplots_adjust(hspace=0, wspace=0)
for i in range(1, columns*rows +1):
fig.add_subplot(rows, columns, i)
plt.imshow(images[i-1])
plt.gca().axes.get_yaxis().set_visible(False)
plt.show()
def stack_reconstructions(images):
w, h = images[0].size[0], images[0].size[1]
img = Image.new("RGB", (len(images)*w, h))
for i, img_ in enumerate(images):
img.paste(img_, (i*w,0))
return img
p_generate = jax.pmap(generate, "batch")
p_get_images = jax.pmap(get_images, "batch")
bart_params = replicate(model.params)
vqgan_params = replicate(vqgan.params)
# ## CLIP Scoring
from transformers import CLIPProcessor, FlaxCLIPModel
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# st.write("FlaxCLIPModel")
# print("Initialize FlaxCLIPModel")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# st.write("CLIPProcessor")
# print("Initialize CLIPProcessor")
def hallucinate(prompt, num_images=64):
prompt = [prompt] * jax.device_count()
inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
inputs = shard(inputs)
all_images = []
for i in range(num_images // jax.device_count()):
key = random.randint(0, 1e7)
rng = jax.random.PRNGKey(key)
rngs = jax.random.split(rng, jax.local_device_count())
indices = p_generate(inputs, rngs, bart_params).sequences
indices = indices[:, :, 1:]
images = p_get_images(indices, vqgan_params)
images = np.squeeze(np.asarray(images), 1)
for image in images:
all_images.append(custom_to_pil(image))
return all_images
def clip_top_k(prompt, images, k=8):
inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
outputs = clip(**inputs)
logits = outputs.logits_per_text
scores = np.array(logits[0]).argsort()[-k:][::-1]
return [images[score] for score in scores]
def captioned_strip(images, caption):
increased_h = 0 if caption is None else 48
w, h = images[0].size[0], images[0].size[1]
img = Image.new("RGB", (len(images)*w, h + increased_h))
for i, img_ in enumerate(images):
img.paste(img_, (i*w, increased_h))
if caption is not None:
draw = ImageDraw.Draw(img)
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
draw.text((20, 3), caption, (255,255,255), font=font)
return img
# Controls
num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)
prompt = st.text_input("What do you want to see?")
if prompt != "":
st.write(f"Generating candidates for: {prompt}")
images = hallucinate(prompt, num_images=num_images)
images = clip_top_k(prompt, images, k=num_preds)
predictions_strip = captioned_strip(images, None)
st.image(predictions_strip)