File size: 3,836 Bytes
e4174c1
 
 
 
 
384dde1
e4174c1
e461401
 
 
e4174c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384dde1
e4174c1
 
 
 
 
 
 
 
 
 
 
57e2f92
 
c2592bc
 
e461401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa90bae
e461401
 
c2592bc
 
591bdf5
 
c2592bc
e461401
c2592bc
3fd0786
 
 
384dde1
 
 
3fd0786
c2592bc
 
 
f247552
 
c2592bc
384dde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from flax.jax_utils import replicate
from jax import pmap
from flax.training.common_utils import shard
import jax
import jax.numpy as jnp
import gradio as gr

from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

from pathlib import Path
from PIL import Image
import numpy as np

from diffusers import FlaxStableDiffusionPipeline


# --- Accelerator / JAX backend setup (runs at import time) ---
import os
# When the TPU_NAME environment variable is present, point JAX at that TPU;
# otherwise only print a hint and let JAX use its default backend.
if 'TPU_NAME' in os.environ:
    import requests
    # TPU_DRIVER_MODE is a module-level flag so the version request below is
    # issued at most once per interpreter (e.g. when this code is re-run).
    if 'TPU_DRIVER_MODE' not in globals():
        # The second ':'-separated field of TPU_NAME is used as the host part
        # (presumably TPU_NAME looks like 'scheme://host:port' -- TODO confirm).
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # NOTE(review): the response is stored but never inspected, so a failed
        # request goes unnoticed here.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1


    # Configure JAX to use the "tpu_driver" backend at the address in TPU_NAME.
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

# Query the devices JAX can see and report their count and kind.
import jax
jax.local_devices()
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")

def sd2_inference(pipeline, prompts, params, seed=42, num_inference_steps=50):
    """Generate one image per prompt with a Flax diffusion pipeline.

    The prompts are tokenized with ``pipeline.prepare_inputs``, the parameters
    are replicated and the inputs sharded across the JAX devices, and the
    pipeline is invoked with ``jit=True``.

    Args:
        pipeline: Object providing ``prepare_inputs``, ``__call__`` and
            ``numpy_to_pil`` (here a ``FlaxStableDiffusionPipeline``).
        prompts: A prompt string or a sequence of prompt strings.
        params: Pipeline parameters; replicated to every device on each call.
        seed: Integer seed for the JAX PRNG key.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The result of ``pipeline.numpy_to_pil`` for the requested prompts, in
        prompt order (one image per prompt).
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    else:
        prompts = list(prompts)
    requested = len(prompts)

    # shard() reshapes the batch to (devices, -1, ...), which fails unless the
    # batch size is a multiple of the device count. Pad by repeating the last
    # prompt and drop the surplus images afterwards.
    device_count = jax.device_count()
    shortfall = -requested % device_count
    if shortfall:
        prompts = prompts + [prompts[-1]] * shortfall

    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = pipeline.prepare_inputs(prompts)
    params = replicate(params)
    # One PRNG key per device so every shard gets its own noise.
    prng_seed = jax.random.split(prng_seed, device_count)
    prompt_ids = shard(prompt_ids)
    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Collapse the (devices, per_device_batch) leading axes into one batch axis.
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    # Keep only the images that correspond to the caller's prompts.
    images = pipeline.numpy_to_pil(images[:requested])
    return images



# Hugging Face access token, read from the HFAUTH environment variable.
# Raises KeyError at import time when the variable is not set.
HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load Model
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
# Text-to-image: the "bf16" revision of Stable Diffusion v1-4, loaded with
# bfloat16 dtype. `from_pretrained` returns the pipeline and its parameters
# as two separate objects.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token = HF_ACCESS_TOKEN,
    revision="bf16",
    dtype=jnp.bfloat16,
)



# Image-to-text: feature extractor, tokenizer and encoder-decoder model are
# all loaded from the same checkpoint id.
loc = "ydshieh/vit-gpt2-coco-en"

feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)

# Keyword arguments forwarded to `model.generate` for every caption.
gen_kwargs = {"max_length": 16, "num_beams": 4}


# This takes some time when compiling the first time, but subsequent inference will be much faster

def generate(pixel_values):
    """Run the captioning model on `pixel_values` and return the generated token id sequences.

    Generation options come from the module-level `gen_kwargs`.
    """
    generation = model.generate(pixel_values, **gen_kwargs)
    return generation.sequences
    
    
def predict(image):
    """Caption `image` and return the decoded captions as a list of stripped strings."""
    # Preprocess the image into NumPy pixel values for the model.
    encoded = feature_extractor(images=image, return_tensors="np")
    token_ids = generate(encoded.pixel_values)
    # Decode token ids to text, dropping special tokens, then trim whitespace.
    return [
        caption.strip()
        for caption in tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    ]
    
def image2text(image):
    """Return the first caption that `predict` produces for `image`."""
    return predict(image)[0]


def text_to_image_and_image_to_text(text=None, image=None):
    """Run whichever direction(s) of the text/image round trip have input.

    Args:
        text: Prompt to turn into an image. ``None``, an empty string, or a
            whitespace-only string means "no prompt" and skips image generation
            (an empty textbox submits ``""`` rather than ``None``).
        image: Image to caption, or ``None`` to skip captioning.

    Returns:
        A ``(img, txt)`` tuple: the generated image (or ``None``) and the
        caption for ``image`` (or ``None``).
    """
    txt = None
    img = None
    # Identity check: comparing an arbitrary image object to None with `!=`
    # relies on that object's own equality implementation.
    if image is not None:
        txt = image2text(image)
    # Skip blank prompts so no diffusion run is spent on an empty input.
    if text is not None and text.strip():
        images = sd2_inference(pipeline, [text], params, seed=42, num_inference_steps=5)
        img = images[0]
    return img, txt


if __name__ == '__main__':
    # Input widgets: an optional prompt and an optional PIL image.
    prompt_box = gr.inputs.Textbox(
        placeholder="Enter the text to Encode to an image",
        label="Text to Encode to Image ",
        lines=1,
        optional=True,
    )
    image_upload = gr.Image(
        type="pil",
        label="Image to Decode to text",
        optional=True,
    )

    # Output widgets, in the same order as the (img, txt) tuple the handler returns.
    encoded_image = gr.outputs.Image(type="pil", label="Encoded Image")
    decoded_text = gr.outputs.Textbox(label="Decoded Text")

    # Wire the handler to the widgets and start the web app.
    interFace = gr.Interface(
        fn=text_to_image_and_image_to_text,
        inputs=[prompt_box, image_upload],
        outputs=[encoded_image, decoded_text],
        title="T2I2T",
        description="T2I2T: Text2Image2Text imformation transmiter",
        theme='gradio/soft',
    )
    interFace.launch()