ryaalbr committed
Commit e7f40b6
Parent: d486d0b

Update app.py

Files changed (1)
  1. app.py +135 -11
app.py CHANGED
@@ -8,12 +8,37 @@ import clip
 import pickle
 import requests
 import torch
+import os
+from huggingface_hub import hf_hub_download
+from torch import nn
+import torch.nn.functional as nnf
+import sys
+from typing import Tuple, List, Union, Optional
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+
+
+N = type(None)
+V = np.array
+ARRAY = np.ndarray
+ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+VS = Union[Tuple[V, ...], List[V]]
+VN = Union[V, N]
+VNS = Union[VS, N]
+T = torch.Tensor
+TS = Union[Tuple[T, ...], List[T]]
+TN = Optional[T]
+TNS = Union[Tuple[TN, ...], List[TN]]
+TSN = Optional[TS]
+TA = Union[T, ARRAY]
+
+D = torch.device
+CPU = torch.device('cpu')
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # # Load the pre-trained model and processor
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 #orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
@@ -21,12 +46,19 @@ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 # Load the Unsplash dataset
 dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
 
+# Load gpt and modifed weights for captions
+gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+conceptual_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights", filename="conceptual_weights.pt")
+coco_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights", filename="coco_weights.pt")
+
+
 height = 256 # height for resizing images
 
 def predict(image, labels):
     with torch.no_grad():
-        inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
-        outputs = model(**inputs)
+        inputs = clip_processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
+        outputs = clip_model(**inputs)
         logits_per_image = outputs.logits_per_image # this is the image-text similarity score
         probs = logits_per_image.softmax(dim=1).cpu().numpy() # we can take the softmax to get the label probabilities
         return {k: float(v) for k, v in zip(labels, probs[0])}
@@ -50,11 +82,103 @@ def rand_image():
 def set_labels(text):
     return text.split(",")
 
-get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])
-def generate_text(image, model_name):
-    return get_caption(image, model_name)
+# get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])
+# def generate_text(image, model_name):
+# return get_caption(image, model_name)
+
+
+class MLP(nn.Module):
+
+    def forward(self, x: T) -> T:
+        return self.model(x)
+
+    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+        super(MLP, self).__init__()
+        layers = []
+        for i in range(len(sizes) -1):
+            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+            if i < len(sizes) - 2:
+                layers.append(act())
+        self.model = nn.Sequential(*layers)
+
+
+class ClipCaptionModel(nn.Module):
+
+    def get_dummy_token(self, batch_size: int, device: D) -> T:
+        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+        embedding_text = self.gpt.transformer.wte(tokens)
+        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+        if labels is not None:
+            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+            labels = torch.cat((dummy_token, tokens), dim=1)
+        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+        return out
+
+    def __init__(self, prefix_length: int, prefix_size: int = 512):
+        super(ClipCaptionModel, self).__init__()
+        self.prefix_length = prefix_length
+        self.gpt = gpt
+        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+        if prefix_length > 10: # not enough memory
+            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+        else:
+            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+#clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+
+def get_caption(img,model_name):
+    prefix_length = 10
+
+    model = ClipCaptionModel(prefix_length)
+
+    if model_name == "COCO":
+        model_path = coco_weight
+    else:
+        model_path = conceptual_weight
+    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+    model = model.eval()
+    model = model.to(device)
+
+    input = clip_processor(images=img, return_tensors="pt").to(device)
+    with torch.no_grad():
+        prefix = clip_model.get_image_features(**input)
+
+    # image = preprocess(img).unsqueeze(0).to(device)
+    # with torch.no_grad():
+    # prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+    output = model.gpt.generate(inputs_embeds=prefix_embed,
+                                num_beams=1,
+                                do_sample=False,
+                                num_return_sequences=1,
+                                no_repeat_ngram_size=1,
+                                max_new_tokens = 67,
+                                pad_token_id = tokenizer.eos_token_id,
+                                eos_token_id = tokenizer.encode('.')[0],
+                                renormalize_logits = True)
+    generated_text_prefix = tokenizer.decode(output[0], skip_special_tokens=True)
+    return generated_text_prefix[:-1] if generated_text_prefix[-1] == "." else generated_text_prefix #remove period at end if present
 
-# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
 # def search_images(text):
 # return get_images(text, api_name="images")
 
@@ -68,8 +192,8 @@ def search(search_query):
     with torch.no_grad():
 
         # Encode and normalize the description using CLIP (HF CLIP)
-        inputs = processor(text=search_query, images=None, return_tensors="pt", padding=True)
-        text_encoded = model.get_text_features(**inputs)
+        inputs = clip_processor(text=search_query, images=None, return_tensors="pt", padding=True)
+        text_encoded = clip_model.get_text_features(**inputs)
 
         # # Encode and normalize the description using CLIP (original CLIP)
         # text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
@@ -163,7 +287,7 @@ with gr.Blocks() as demo:
         caption = gr.Textbox(label='Caption', elem_classes="caption-text")
         get_btn_cap.click(fn=rand_image, outputs=im_cap)
         #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
-        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)
+        caption_btn.click(get_caption, inputs=[im_cap, model_name], outputs=caption)
 
     with gr.Tab("Search"):
         instructions = """## Instructions:
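
For context, a minimal usage sketch of the new local captioning path (illustrative, not part of the commit): it assumes the code in app.py above has already run, so that clip_model, clip_processor, gpt, tokenizer, the downloaded prefix weights, and get_caption are in scope; the image path is a placeholder.

from PIL import Image

img = Image.open("example.jpg")         # placeholder path; any RGB photo works
print(get_caption(img, "COCO"))         # "COCO" selects the COCO prefix weights
print(get_caption(img, "Conceptual"))   # any other value falls back to the Conceptual Captions weights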