Spaces:
Sleeping
Sleeping
File size: 3,836 Bytes
e4174c1 384dde1 e4174c1 e461401 e4174c1 384dde1 e4174c1 57e2f92 c2592bc e461401 fa90bae e461401 c2592bc 591bdf5 c2592bc e461401 c2592bc 3fd0786 384dde1 3fd0786 c2592bc f247552 c2592bc 384dde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from flax.jax_utils import replicate
from jax import pmap
from flax.training.common_utils import shard
import jax
import jax.numpy as jnp
import gradio as gr
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
from pathlib import Path
from PIL import Image
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
import os
# Optional remote-TPU setup. This snippet appears to be copied from a Colab
# notebook (see the "Runtime/Change runtime type" hint in the else branch);
# it only does anything when the TPU_NAME environment variable is set.
if 'TPU_NAME' in os.environ:
    import requests
    # Guard so the driver-version request is sent at most once per interpreter
    # session (e.g. if this code is executed again in the same process).
    if 'TPU_DRIVER_MODE' not in globals():
        # The second ':'-separated field of TPU_NAME is used as the URL host part;
        # presumably TPU_NAME looks like "grpc://<ip>:<port>" — TODO confirm.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # Requests the "tpu_driver_nightly" version from the TPU host (per the URL path).
        # NOTE(review): the response is never inspected, so a failed request goes unnoticed.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    # Route JAX to the remote TPU through the "tpu_driver" backend.
    # NOTE(review): `jax.config.config` and the `jax_xla_backend` flag may not exist
    # in newer JAX releases — verify against the JAX version this app pins.
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
import jax

# Query the local device list first, then record and report which
# accelerator backend JAX ended up on.
jax.local_devices()
device_type = jax.devices()[0].device_kind
num_devices = jax.device_count()
print(f"Found {num_devices} JAX devices of type {device_type}.")
def sd2_inference(pipeline, prompts, params, seed=42, num_inference_steps=50):
    """Generate one image per prompt with a Flax Stable Diffusion pipeline.

    The batch is split data-parallel across all JAX devices.

    Args:
        pipeline: a ``FlaxStableDiffusionPipeline`` (anything exposing
            ``prepare_inputs``, ``__call__(..., jit=True)`` and ``numpy_to_pil``).
        prompts: sequence of prompt strings. Any length is accepted; the batch
            is padded internally to a multiple of the device count.
        params: the pipeline parameters, not yet replicated across devices.
        seed: base PRNG seed; it is split into one key per device.
        num_inference_steps: number of denoising steps.

    Returns:
        A list of PIL images, one per entry of ``prompts``, in the same order.
    """
    prompts = list(prompts)
    n_requested = len(prompts)
    n_devices = jax.device_count()

    # shard() splits the leading batch axis across the devices, so the batch size
    # must be a multiple of the device count (a single prompt fails on e.g. an
    # 8-core TPU). Pad by cycling the given prompts; surplus images are dropped below.
    padded_len = -(-n_requested // n_devices) * n_devices
    prompts = [prompts[i % n_requested] for i in range(padded_len)]

    prng_seed = jax.random.split(jax.random.PRNGKey(seed), n_devices)
    prompt_ids = shard(pipeline.prepare_inputs(prompts))
    # The jit=True call path expects per-device copies of the weights.
    params = replicate(params)

    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Merge the two leading (device, per-device batch) axes; the trailing three
    # axes are kept as-is (presumably height, width, channels).
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(images)[:n_requested]
HF_ACCESS_TOKEN = os.environ["HFAUTH"]
# Load the Stable Diffusion v1.4 Flax pipeline and its weights in bfloat16,
# authenticating with the token read from the HFAUTH environment variable.
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
    use_auth_token=HF_ACCESS_TOKEN,
)
loc = "ydshieh/vit-gpt2-coco-en"
# The image-captioning model and both of its pre/post-processors are loaded
# from the same Hub checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)
# Decoding settings forwarded to model.generate(): 4 beams, at most 16 tokens.
gen_kwargs = dict(max_length=16, num_beams=4)
# The first call takes some time because generation gets compiled;
# subsequent calls are much faster.
def generate(pixel_values):
    """Return the token-id sequences produced by ``model.generate`` using ``gen_kwargs``."""
    return model.generate(pixel_values, **gen_kwargs).sequences
def predict(image):
    """Caption ``image`` and return the decoded captions, stripped of surrounding whitespace."""
    batch = feature_extractor(images=image, return_tensors="np")
    token_ids = generate(batch.pixel_values)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return [caption.strip() for caption in decoded]
def image2text(image):
    """Return the first caption ``predict`` produces for ``image``."""
    captions = predict(image)
    return captions[0]
def text_to_image_and_image_to_text(text=None, image=None, *, seed=42, num_inference_steps=5):
    """Run whichever direction(s) were requested: text -> image and/or image -> text.

    Args:
        text: prompt for Stable Diffusion. ``None``, ``""`` or a whitespace-only
            string means "no prompt"; Gradio's Textbox sends ``""`` when left empty.
        image: image to caption, or ``None`` when no image was supplied.
        seed: PRNG seed for diffusion sampling (keyword-only, default 42).
        num_inference_steps: denoising steps for diffusion (keyword-only, default 5).

    Returns:
        ``(img, txt)``: the generated image (or ``None`` if no prompt was given)
        and the caption (or ``None`` if no image was given).
    """
    txt = None
    img = None
    # Identity test: `!= None` would dispatch to the image object's __eq__.
    if image is not None:
        txt = image2text(image)
    # Skip diffusion for an empty/blank prompt instead of rendering noise for "".
    if text and text.strip():
        images = sd2_inference(
            pipeline, [text], params, seed=seed, num_inference_steps=num_inference_steps
        )
        img = images[0]
    return img, txt
if __name__ == '__main__':
    # Two inputs, either of which may be left empty: a prompt to render as an
    # image and/or an image to caption. Top-level component classes are used
    # because the gr.inputs / gr.outputs namespaces are deprecated (removed in Gradio 4).
    interFace = gr.Interface(
        fn=text_to_image_and_image_to_text,
        inputs=[
            gr.Textbox(
                placeholder="Enter the text to Encode to an image",
                label="Text to Encode to Image ",
                lines=1,
            ),
            gr.Image(type="pil", label="Image to Decode to text"),
        ],
        outputs=[
            gr.Image(type="pil", label="Encoded Image"),
            gr.Textbox(label="Decoded Text"),
        ],
        title="T2I2T",
        description="T2I2T: Text2Image2Text information transmitter",
        theme='gradio/soft',
    )
    interFace.launch()