jruneofficial committed on
Commit
2d0cabb
1 Parent(s): 2573f42

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModelForSequenceClassification
5
+ from transformers import AutoTokenizer
6
+ from transformers import pipeline
7
+
8
+ import torch
9
+ import os
10
+ import numpy as np
11
+ from matplotlib import pyplot as plt
12
+ from PIL import Image
13
+
14
+ from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int
15
+
16
# Static settings for the text-to-image demo. "model_name" selects the
# classifier checkpoint, "base_model_name" the tokenizer checkpoint and
# "image_gen_model" the BigGAN checkpoint.
config = dict(
    model_name="smangrul/Multimodal-Challenge",
    base_model_name="distilbert-base-uncased",
    image_gen_model="biggan-deep-128",
    max_length=20,
    freeze_text_model=True,
    freeze_image_gen_model=True,
    text_embedding_dim=768,
    class_embedding_dim=128,
)

# Truncation value handed to image generation.
truncation = 0.4

# CUDA is only selected when this flag is flipped by hand.
is_gpu = False
device = torch.device('cuda' if is_gpu else 'cpu')
print(device)
31
+
32
# Text classifier. The access token is read from the
# 'huggingface-api-token' environment variable (None when it is unset).
# Module.to() and Module.eval() both return the module, so they are chained.
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get('huggingface-api-token'),
)
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
model.to(device).eval()

# Image generator, switched to inference mode on the same device.
gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device).eval()
print("Models were loaded")
42
+
43
+
44
def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate one image with the module-level BigGAN ``gan_model``.

    When ``int_index`` is given it is one-hot encoded and passed through
    ``gan_model.embeddings`` to obtain the class embedding; otherwise
    ``dense_class_vector`` is used directly as the class embedding.

    Args:
        dense_class_vector: Class embedding as a ``torch.Tensor`` or
            ``np.ndarray`` holding 128 values. Ignored when ``int_index``
            is given.
        int_index: Integer class index handed to ``one_hot_from_int``.
        noise_seed_vector: Optional tensor; the integer part of the sum of
            its elements seeds the noise sampler, so equal inputs yield
            equal noise.
        truncation: Truncation value used for both the noise sampler and
            the generator.

    Returns:
        The first image of the generated batch as a ``np.uint8`` array with
        the channel axis moved last.

    Raises:
        ValueError: If neither ``int_index`` nor ``dense_class_vector`` is
            provided.
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if int_index is None and dense_class_vector is None:
        raise ValueError("Either int_index or dense_class_vector must be provided")

    # All inputs must live on the same device as the generator weights;
    # otherwise a model moved to CUDA would receive CPU tensors.
    model_device = next(gan_model.parameters()).device

    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector).to(model_device)

    # Inference only: keep the embedding lookup out of autograd as well.
    with torch.no_grad():
        if int_index is not None:
            class_vector = one_hot_from_int([int_index], batch_size=1)
            class_vector = torch.from_numpy(class_vector).to(model_device)
            dense_class_vector = gan_model.embeddings(class_vector)
        else:
            if isinstance(dense_class_vector, np.ndarray):
                dense_class_vector = torch.tensor(dense_class_vector)
            # Match the noise dtype so a float64 array does not promote the
            # concatenated generator input.
            dense_class_vector = dense_class_vector.view(1, 128).to(
                device=model_device, dtype=noise_vector.dtype
            )

        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)

        # Generate an image
        output = gan_model.generator(input_vector, truncation)

    # Move the channel axis last, then rescale to the 0..255 pixel range
    # (assumes the generator output lies in [-1, 1] -- TODO confirm).
    output = output.cpu().numpy().transpose((0, 2, 3, 1))
    output = ((output + 1.0) / 2.0) * 256
    output.clip(0, 255, out=output)
    return output[0].astype(np.uint8)
68
+
69
+
70
def print_image(numpy_array):
    """Display a numpy uint8 array as an image via matplotlib."""
    plt.imshow(Image.fromarray(numpy_array))
    plt.show()
76
+
77
+
78
def text_to_image(text):
    """Turn a text prompt into a generated PIL image.

    The prompt is tokenized and classified by the module-level ``model``;
    the arg-max class index is then passed to ``generate_image``. The token
    ids also serve as the noise seed, so the same prompt yields the same
    image.

    Args:
        text: The prompt string.

    Returns:
        A ``PIL.Image.Image`` built from the generated uint8 array.
    """
    # truncation=True caps the sequence at the tokenizer's maximum length so
    # an overly long prompt cannot exceed what the model accepts.
    tokens = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)

    # Index of the highest logit for the single sequence in the batch.
    pred_int_index = torch.argmax(lm_output.logits[0], dim=-1).item()
    print(pred_int_index)

    # Now generate an image (a numpy array)
    numpy_image = generate_image(
        int_index=pred_int_index,
        truncation=truncation,
        noise_seed_vector=tokens,
    )
    return Image.fromarray(numpy_image)
93
+
94
# Sample prompts passed to the Gradio interface as clickable examples.
examples = [
    "a high resoltuion photo of a pizza from famous food magzine.",
    "this is a photo of my pet golden retriever.",
    "this is a photo of a trouble some street cat.",
    "a blur image of coral reef.",
    "a yellow taxi cab commonly found in USA.",
    "Once upon a time, there was a black ship full of pirates.",
    "a photo of a large castle.",
    "a sketch of an old Church",
]
102
+
103
+ if __name__ == '__main__':
104
+ interFace = gr.Interface(fn=text_to_image,
105
+ inputs=gr.inputs.Textbox(placeholder="Enter the text to generate an image", label="Text "
106
+ "query",
107
+ lines=1),
108
+ outputs=gr.outputs.Image(type="auto", label="Generated Image"),
109
+ verbose=True,
110
+ examples=examples,
111
+ title="Generate Image from Text",
112
+ description="",
113
+ theme="huggingface")
114
+ interFace.launch()