#!/usr/bin/env python
# coding: utf-8

# Uncomment to run on cpu
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

import random

import jax
import flax.linen as nn
from flax.training.common_utils import shard
from flax.jax_utils import replicate, unreplicate

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

import requests
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt

from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel

import streamlit as st

st.write("Loading model...")

# TODO: set those args in a config file
OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
BOS_TOKEN_ID = 16384
BASE_MODEL = 'flax-community/dalle-mini'


class CustomFlaxBartModule(FlaxBartModule):
    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )

        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # the decoder has a different config: copy the encoder settings,
        # then override the output length and vocabulary for image tokens
        decoder_config = BartConfig(**self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)


class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    def setup(self):
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            OUTPUT_VOCAB_SIZE,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))


class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    module_class = CustomFlaxBartForConditionalGenerationModule


# create our model
# FIXME: Save tokenizer to hub so we can load from there
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
model.config.force_bos_token_to_be_generated = False
model.config.forced_bos_token_id = None
model.config.forced_eos_token_id = None

vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")


def custom_to_pil(x):
    # x: HxWx3 float array, expected (approximately) in [0, 1]
    x = np.clip(x, 0., 1.)
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x
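# Illustrative sketch (commented out, in the spirit of the optional CPU block
# above): custom_to_pil expects one HxWx3 float array roughly in [0, 1], the
# shape of a single decoded VQGAN frame after the squeeze in hallucinate below.
# The random array is only a stand-in for a decoded image.
#_noise = np.random.rand(256, 256, 3).astype(np.float32)
#_im = custom_to_pil(_noise)
#assert _im.mode == "RGB" and _im.size == (256, 256)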
def generate(input, rng, params):
    return model.generate(
        **input,
        max_length=257,  # 256 image tokens + BOS
        num_beams=1,
        do_sample=True,
        prng_key=rng,
        eos_token_id=50000,  # outside the image vocabulary, so sampling never stops early
        pad_token_id=50000,
        params=params,
    )


def get_images(indices, params):
    return vqgan.decode_code(indices, params=params)


def plot_images(images):
    fig = plt.figure(figsize=(40, 20))
    columns = 4
    rows = 2
    plt.subplots_adjust(hspace=0, wspace=0)
    for i in range(1, columns * rows + 1):
        fig.add_subplot(rows, columns, i)
        plt.imshow(images[i - 1])
    plt.gca().axes.get_yaxis().set_visible(False)
    plt.show()


def stack_reconstructions(images):
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w, h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i * w, 0))
    return img


p_generate = jax.pmap(generate, "batch")
p_get_images = jax.pmap(get_images, "batch")

bart_params = replicate(model.params)
vqgan_params = replicate(vqgan.params)


# ## CLIP Scoring

from transformers import CLIPProcessor, FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def hallucinate(prompt, num_images=64):
    # one copy of the prompt per device; each pmap step yields device_count() images
    prompt = [prompt] * jax.device_count()
    inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
    inputs = shard(inputs)

    all_images = []
    # at least one pmap step, even when num_images < jax.device_count()
    for i in range(max(num_images // jax.device_count(), 1)):
        key = random.randint(0, 10_000_000)  # randint needs ints, not 1e7
        rng = jax.random.PRNGKey(key)
        rngs = jax.random.split(rng, jax.local_device_count())
        indices = p_generate(inputs, rngs, bart_params).sequences
        indices = indices[:, :, 1:]  # drop the BOS token

        images = p_get_images(indices, vqgan_params)
        images = np.squeeze(np.asarray(images), 1)  # remove the per-device batch dimension
        for image in images:
            all_images.append(custom_to_pil(image))
    return all_images


def clip_top_k(prompt, images, k=8):
    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    outputs = clip(**inputs)
    logits = outputs.logits_per_text
    # indices of the k best-scoring images, best first
    top = np.array(logits[0]).argsort()[-k:][::-1]
    return [images[i] for i in top]


def captioned_strip(images, caption):
    increased_h = 0 if caption is None else 48
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w, h + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i * w, increased_h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img


# Controls
num_images = st.sidebar.slider("Candidates to generate", 1, 64, 8, 1)
num_preds = st.sidebar.slider("Best predictions to show", 1, 8, 1, 1)

prompt = st.text_input("What do you want to see?")

if prompt != "":
    st.write(f"Generating candidates for: {prompt}")
    images = hallucinate(prompt, num_images=num_images)
    images = clip_top_k(prompt, images, k=num_preds)
    predictions_strip = captioned_strip(images, None)
    st.image(predictions_strip)
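# Usage sketch, assuming this file is saved as app.py and Streamlit is
# installed (this is the standard Streamlit launcher, nothing repo-specific):
#   streamlit run app.py
#
# Headless sketch without the UI (hypothetical helper built from the
# functions above; the prompt and filenames are illustrative):
#def generate_and_rank(prompt, candidates=8, k=4):
#    images = hallucinate(prompt, num_images=candidates)
#    return clip_top_k(prompt, images, k=k)
#
#for i, im in enumerate(generate_and_rank("a watercolor of a lighthouse")):
#    im.save(f"sample_{i}.png")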