#!/usr/bin/env python
# coding: utf-8
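# Round-trip DALL·E mini: generate an image from a text prompt with DALL·E mini,
# caption it with a ViT-GPT2 captioner, then feed the caption back in as the
# next prompt (see run_inference below), inspired by round-trip translation.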
import os
# Uncomment to run on cpu
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_SILENT"] = "true"
import random
import torch
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image, ImageDraw, ImageFont
from functools import partial
from transformers import FlaxCLIPModel, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
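# DALL·E mini checkpoint (a wandb artifact reference) and the VQGAN that decodes its image tokens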
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
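# load both models without initializing weights on the host (_do_init=False returns params separately)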
model, params = DalleBart.from_pretrained(
    DALLE_REPO, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
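# ViT-GPT2 image captioner: one checkpoint provides the feature extractor, tokenizer, and model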
caption_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
feature_extractor = ViTFeatureExtractor.from_pretrained(caption_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(caption_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(caption_checkpoint).to(device)
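# helper that tiles same-sized images into a strip with an optional caption bar (not used in the app below)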
def captioned_strip(images, caption=None, rows=1):
    increased_h = 0 if caption is None else 24
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img
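# decode VQGAN codebook indices back into RGB images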
def get_images(indices, params):
    return vqgan.decode_code(indices, params=params)
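# caption a PIL image with ViT-GPT2 (beam search)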
def predict_caption(image, max_length=128, num_beams=4):
    image = image.convert("RGB")
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    caption_ids = viz_model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]
    # drop the GPT-2 end-of-text token and keep only the first line
    caption_text = tokenizer.decode(caption_ids).replace("<|endoftext|>", "").split("\n")[0].strip()
    return caption_text
# model inference: run DALL·E mini generation on all devices in parallel;
# the sampling args (argnums 3-6) must be static for pmap
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
# decode image tokens back to pixels with the VQGAN, one shard per device
p_decode = jax.pmap(get_images, axis_name="batch")
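# replicate the parameters once across all local devices for the pmapped functions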
params = replicate(params)
vqgan_params = replicate(vqgan_params)
processor = DalleBartProcessor.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
print("Initialized DalleBartProcessor")
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")
def hallucinate(prompt, num_images=8):
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0
    print(f"Prompts: {prompt}")
    prompt = [prompt] * jax.device_count()
    inputs = processor(prompt)
    inputs = replicate(inputs)
    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)
    images = []
    for i in range(max(num_images // jax.device_count(), 1)):
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            inputs,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)
        print(f"Finished decoding image {i}")
    return images
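# the round trip: prompt -> image -> caption -> next prompt, repeated num_roundtrips times;
# returns a flat list of (title HTML, image, caption HTML) triples for Gradio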
def run_inference(prompt, num_roundtrips=3, num_images=1):
    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")
        output_title = f"""
        <font size="+3">
        <b>[Roundtrip {i + 1}]</b><br>
        Prompt: {prompt}<br>
        🎨 :<br></font>"""
        output_caption = f"""
        <font size="+3">
        🤖💬 : {caption}<br>
        </font>
        """
        outputs.append(output_title)
        outputs.append(image)
        outputs.append(output_caption)
        prompt = caption
    print("Done.")
    return outputs
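# Gradio interface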
inputs = gr.inputs.Textbox(label="What prompt do you want to start with?", default="cookie monster the horror movie")
# num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
num_roundtrips = 3
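# three output components per roundtrip, in the same order run_inference returns them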
outputs = []
for _ in range(int(num_roundtrips)):
    outputs.append(gr.outputs.HTML(label=""))
    outputs.append(gr.outputs.Image(label="", type="pil"))
    outputs.append(gr.outputs.HTML(label=""))
description = """
Round trip DALL路E-mini iterates between DALL路E generation and image captioning, inspired by round trip translation! FYI: runtime is forever (~1hr or possibly longer) because the app is running on CPU.
"""
article = "<p style='text-align: center'>Put together by: Najoung Kim | DALL·E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"
gr.Interface(
    fn=run_inference,
    inputs=[inputs],
    outputs=outputs,
    title="Round Trip DALL·E mini 🎨🔁🤖💬",
    description=description,
    article=article,
    theme="default",
    css=".output-image, .input-image, .image-preview {height: 256px !important}",
).launch(enable_queue=False)