ryaalbr committed
Commit
4150f63
1 Parent(s): a3cfe93

Update app.py

Files changed (1):
  app.py +16 -7
app.py CHANGED
@@ -12,8 +12,8 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

# # Load the pre-trained model and processor
- # model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
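The newly uncommented loading calls assume that `CLIPModel` and `CLIPProcessor` are imported from Hugging Face `transformers` somewhere above this hunk; the imports sit outside the diff context, so this is an assumption. A self-contained sketch of that loading step, with the import made explicit, would look like:

# Sketch only; assumes the transformers package is installed and the import lives near the top of app.py.
from transformers import CLIPModel, CLIPProcessor
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")          # Hugging Face CLIP weights
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")  # paired tokenizer + image preprocessor
model.to(device).eval()  # inference only: move to the selected device and disable dropout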
 
@@ -64,10 +64,19 @@ with open(emb_filename, 'rb') as emb:

def search(search_query):

+
+
+
+
    with torch.no_grad():
-         # Encode and normalize the description using CLIP
-         text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
-         text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+
+         # Encode and normalize the description using CLIP (HF CLIP)
+         inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
+         text_encoded = model.get_text_features(**inputs)
+
+         # # Encode and normalize the description using CLIP (original CLIP)
+         # text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
+         # text_encoded /= text_encoded.norm(dim=-1, keepdim=True)


        # Retrieve the description vector
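The added encoding path passes `text` to the processor even though the function argument is named `search_query`, and it drops the L2 normalization that the commented-out original-CLIP path applied. A sketch of the presumably intended step (an assumption, not what this commit ships) could be:

# Sketch only; assumes `model` and `processor` are the Hugging Face CLIP objects loaded above,
# and that the query embedding should be unit-normalized as in the original CLIP code path.
with torch.no_grad():
    inputs = processor(text=[search_query], return_tensors="pt", padding=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # keep tensors on the model's device
    text_encoded = model.get_text_features(**inputs)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)      # unit-normalize for cosine-similarity search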
@@ -136,8 +145,8 @@ with gr.Blocks(css=".caption-text {font-size: 40px !important;}") as demo:
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        get_btn.click(fn=rand_image, outputs=im)
-         im.change(predict2, inputs=[im, labels], outputs=cf)
-         reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)
+         im.change(predict, inputs=[im, labels], outputs=cf)
+         reclass_btn.click(predict, inputs=[im, labels], outputs=cf)

    with gr.Tab("Image Captioning"):
        with gr.Row():
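For context, the `im.change` and `reclass_btn.click` calls above are standard Gradio event wiring: the named function runs with the listed components as inputs and its return value fills the output component. A stripped-down sketch (hypothetical component and function names, not taken from app.py) looks like:

# Sketch only; hypothetical classifier and components, illustrating the .change/.click wiring pattern.
import gradio as gr

def classify(image, label_string):
    # Placeholder: return a uniform confidence over the comma-separated labels.
    labels = [l.strip() for l in label_string.split(",") if l.strip()]
    return {label: 1.0 / len(labels) for label in labels} if labels else {}

with gr.Blocks() as demo:
    im = gr.Image(type="pil")
    labels = gr.Textbox(label="Comma-separated labels")
    cf = gr.Label(num_top_classes=5)
    btn = gr.Button("Reclassify")
    im.change(classify, inputs=[im, labels], outputs=cf)   # runs whenever the image changes
    btn.click(classify, inputs=[im, labels], outputs=cf)   # runs on button press

demo.launch()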
 