#!/usr/bin/env python
# coding: utf-8

import os

# Uncomment to run on cpu
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_SILENT"] = "true"

import random
import re
from functools import partial

import torch
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard, shard_prng_key
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    CLIPProcessor,
    FlaxCLIPModel,
    AutoTokenizer,
    ViTFeatureExtractor,
    VisionEncoderDecoderModel,
)

from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# DALL·E mini and VQGAN checkpoints
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model, params = DalleBart.from_pretrained(
    DALLE_REPO, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

# ViT-GPT2 image captioning model (PyTorch)
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)


def captioned_strip(images, caption=None, rows=1):
    """Paste a list of images into a single strip, optionally with a caption."""
    increased_h = 0 if caption is None else 24
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img


def get_images(indices, params):
    return vqgan.decode_code(indices, params=params)


def predict_caption(image, max_length=128, num_beams=4):
    image = image.convert("RGB")
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    clean_text = lambda x: x.replace("<|endoftext|>", "").split("\n")[0]
    caption_ids = viz_model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]
    caption_text = clean_text(tokenizer.decode(caption_ids))
    return caption_text


# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)


p_get_images = jax.pmap(get_images, "batch")

# replicate parameters across devices for pmap
params = replicate(params)
vqgan_params = replicate(vqgan_params)

processor = DalleBartProcessor.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
print("Initialized DalleBartProcessor")
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")


def hallucinate(prompt, num_images=8):
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0

    print(f"Prompts: {prompt}")
    prompt = [prompt] * jax.device_count()
    inputs = processor(prompt)
    inputs = replicate(inputs)

    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    images = []
    for i in range(max(num_images // jax.device_count(), 1)):
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            inputs,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)
        print(f"Finished decoding image {i}")
    return images


def run_inference(prompt, num_roundtrips=3, num_images=1):
    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")
        output_title = f"""
[Roundtrip {i}]
Prompt: {prompt}
🥑 :
""" output_caption = f""" 🤖💬 : {caption}
""" outputs.append(output_title) outputs.append(image) outputs.append(output_caption) prompt = caption print("Done.") return outputs inputs = gr.inputs.Textbox(label="What prompt do you want to start with?", default="cookie monster the horror movie") # num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?") num_roundtrips = 3 outputs = [] for _ in range(int(num_roundtrips)): outputs.append(gr.outputs.HTML(label="")) outputs.append(gr.Image(label="")) outputs.append(gr.outputs.HTML(label="")) description = """ Round trip DALL·E-mini iterates between DALL·E generation and image captioning, inspired by round trip translation! FYI: runtime is forever (~1hr or possibly longer) because the app is running on CPU. """ article = "

Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption

" gr.Interface( fn=run_inference, inputs=[inputs], outputs=outputs, title="Round Trip DALL·E mini 🥑🔁🤖💬", description=description, article=article, theme="default", css = ".output-image, .input-image, .image-preview {height: 256px !important} " ).launch(enable_queue=False)