text2imgtest / app.py
jruneofficial's picture
Update app.py
e5241f9
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import pipeline
import torch
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int
# Identifiers and sizes for the text classifier and the BigGAN generator.
# NOTE(review): max_length, freeze_*, and *_embedding_dim are not read by the
# code visible in this file.
config = dict(
    model_name="keras-io/multimodal-entailment",
    base_model_name="distilbert-base-uncased",
    image_gen_model="biggan-deep-512",
    max_length=20,
    freeze_text_model=True,
    freeze_image_gen_model=True,
    text_embedding_dim=768,
    class_embedding_dim=128,
)

# Truncation value handed to generate_image() by text_to_image().
truncation = 0.4

# Hard-coded CPU execution; flip the flag by hand to select CUDA.
is_gpu = False
device = torch.device("cuda" if is_gpu else "cpu")
print(device)

# Text classifier; the auth token (if any) is read from the environment.
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get("huggingface-api-token"),
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])

# Image generator, kept in inference mode.
gan_model = BigGAN.from_pretrained(config["image_gen_model"]).to(device).eval()
print("Models were loaded")
def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate one image with the module-level BigGAN ``gan_model``.

    Class conditioning comes either from a class index (``int_index``, turned
    into a one-hot vector and embedded with ``gan_model.embeddings``) or from
    a precomputed dense class embedding (``dense_class_vector``, reshaped to
    ``(1, 128)``). ``int_index`` wins when both are given.

    Args:
        dense_class_vector: tensor, numpy array, or sequence with 128 elements.
            Used only when ``int_index`` is None.
        int_index: integer class index passed to ``one_hot_from_int``.
        noise_seed_vector: optional tensor whose element sum seeds the noise
            sample, so equal inputs reproduce the same image. ``None`` gives
            an unseeded sample.
        truncation: truncation value used for both the noise sample and the
            generator call.

    Returns:
        ``numpy.ndarray`` of dtype uint8, first image of the batch, with the
        channel axis moved last (height, width, channels).

    Raises:
        ValueError: if neither ``int_index`` nor ``dense_class_vector`` is given.
    """
    if int_index is None and dense_class_vector is None:
        raise ValueError("generate_image needs either int_index or dense_class_vector")

    # numpy builds the noise and one-hot vectors on the CPU; move everything to
    # wherever the generator's weights live so torch.cat/generator don't fail.
    gan_device = next(gan_model.parameters()).device

    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector).to(gan_device)

    with torch.no_grad():
        if int_index is not None:
            class_vector = torch.from_numpy(one_hot_from_int([int_index], batch_size=1)).to(gan_device)
            dense_class_vector = gan_model.embeddings(class_vector)
        else:
            # Coerce to the noise dtype: a float64 numpy input would otherwise
            # clash with the float32 noise vector and generator weights.
            dense_class_vector = torch.as_tensor(dense_class_vector).to(
                device=gan_device, dtype=noise_vector.dtype
            )
        dense_class_vector = dense_class_vector.reshape(1, 128)

        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
        output = gan_model.generator(input_vector, truncation)

    # (batch, C, H, W) -> (batch, H, W, C), then map [-1, 1] onto [0, 256) and
    # clip to the uint8 range.
    output = output.cpu().numpy().transpose((0, 2, 3, 1))
    output = np.clip((output + 1.0) / 2.0 * 256, 0, 255)
    return output[0].astype(np.uint8)
def print_image(numpy_array):
    """Show a uint8 numpy image array in a matplotlib window (debugging aid)."""
    pil_picture = Image.fromarray(numpy_array)
    plt.imshow(pil_picture)
    plt.show()
def text_to_image(text):
    """Turn a text prompt into a generated image.

    The classifier ``model`` picks a class index (argmax over its logits) for
    ``text``. ``generate_image`` then renders that class with BigGAN. The token
    ids also seed the noise sample, so the same prompt yields the same image.

    Args:
        text: the prompt string.

    Returns:
        ``PIL.Image.Image`` built from the generated uint8 array.
    """
    # truncation=True cuts prompts at the tokenizer's model_max_length instead
    # of passing an over-long sequence to the classifier (assumes the base
    # model has a fixed maximum input length, as DistilBERT does).
    tokens = tokenizer.encode(
        text, add_special_tokens=True, truncation=True, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)
    pred_int_index = int(torch.argmax(lm_output.logits[0], dim=-1).item())
    print(pred_int_index)

    # Note: ``truncation`` here is the module-level BigGAN setting, unrelated to
    # the tokenizer's truncation flag above.
    numpy_image = generate_image(int_index=pred_int_index,
                                 truncation=truncation,
                                 noise_seed_vector=tokens)
    return Image.fromarray(numpy_image)
# Example prompts shown in the Gradio UI (typos in the originals corrected).
examples = [
    "a high resolution photo of a pizza from a famous food magazine.",
    "this is a photo of my pet golden retriever.",
    "this is a photo of a troublesome street cat.",
    "a blurry image of a coral reef.",
    "a yellow taxi cab commonly found in the USA.",
    "Once upon a time, there was a black ship full of pirates.",
    "a photo of a large castle.",
    "a sketch of an old church.",
]
if __name__ == '__main__':
    # One-line text prompt in, generated picture out.
    prompt_box = gr.inputs.Textbox(
        placeholder="Enter the text to generate an image",
        label="Text query",
        lines=1,
    )
    picture_box = gr.outputs.Image(type="auto", label="Generated Image")
    demo = gr.Interface(
        fn=text_to_image,
        inputs=prompt_box,
        outputs=picture_box,
        examples=examples,
        title="Generate Image from Text",
        description="",
        theme="huggingface",
        verbose=True,
    )
    demo.launch()