import os

import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from pytorch_pretrained_biggan import BigGAN, one_hot_from_int, truncated_noise_sample
from transformers import AutoModelForSequenceClassification, AutoTokenizer

config = {
    "model_name": "keras-io/multimodal-entailment",
    "base_model_name": "distilbert-base-uncased",
    "image_gen_model": "biggan-deep-512",
    "max_length": 20,
    "freeze_text_model": True,
    "freeze_image_gen_model": True,
    "text_embedding_dim": 768,
    "class_embedding_dim": 128,
}

truncation = 0.4
is_gpu = False
device = torch.device('cuda') if is_gpu else torch.device('cpu')
print(device)

# Text classifier that maps the input text to an ImageNet class index.
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get('huggingface-api-token'))
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
model.to(device)
model.eval()

# BigGAN generator used to render an image for the predicted class.
gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device)
gan_model.eval()
print("Models were loaded")


def generate_image(dense_class_vector=None, int_index=None,
                   noise_seed_vector=None, truncation=0.4):
    """Generate one BigGAN image, conditioned either on a dense class
    embedding or on an integer ImageNet class index."""
    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector)

    if int_index is not None:
        class_vector = one_hot_from_int([int_index], batch_size=1)
        class_vector = torch.from_numpy(class_vector)
        dense_class_vector = gan_model.embeddings(class_vector)
    else:
        if isinstance(dense_class_vector, np.ndarray):
            dense_class_vector = torch.tensor(dense_class_vector)
        dense_class_vector = dense_class_vector.view(1, 128)

    input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)

    # Generate an image
    with torch.no_grad():
        output = gan_model.generator(input_vector, truncation)
    output = output.cpu().numpy()
    output = output.transpose((0, 2, 3, 1))  # NCHW -> NHWC
    output = ((output + 1.0) / 2.0) * 256    # [-1, 1] -> [0, 256)
    output.clip(0, 255, out=output)
    output = np.asarray(np.uint8(output[0]), dtype=np.uint8)
    return output


def print_image(numpy_array):
    """Utility function to print a numpy uint8 array as an image."""
    img = Image.fromarray(numpy_array)
    plt.imshow(img)
    plt.show()


def text_to_image(text):
    """Classify the input text, then generate an image for the predicted class."""
    tokens = tokenizer.encode(text, add_special_tokens=True, return_tensors='pt').to(device)
    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)
    pred_int_index = torch.argmax(lm_output.logits[0], dim=-1).cpu().detach().numpy().tolist()
    print(pred_int_index)

    # Now generate an image (a numpy array)
    numpy_image = generate_image(int_index=pred_int_index,
                                 truncation=truncation,
                                 noise_seed_vector=tokens)
    img = Image.fromarray(numpy_image)
    # print_image(numpy_image)
    return img


examples = ["a high resolution photo of a pizza from a famous food magazine.",
            "this is a photo of my pet golden retriever.",
            "this is a photo of a troublesome street cat.",
            "a blurry image of a coral reef.",
            "a yellow taxi cab commonly found in the USA.",
            "Once upon a time, there was a black ship full of pirates.",
            "a photo of a large castle.",
            "a sketch of an old church"]

if __name__ == '__main__':
    interFace = gr.Interface(fn=text_to_image,
                             inputs=gr.inputs.Textbox(placeholder="Enter the text to generate an image",
                                                      label="Text query",
                                                      lines=1),
                             outputs=gr.outputs.Image(type="auto", label="Generated Image"),
                             verbose=True,
                             examples=examples,
                             title="Generate Image from Text",
                             description="",
                             theme="huggingface")
    interFace.launch()