ryaalbr committed
Commit 971e64d
1 Parent(s): 6760107

Update app.py

Files changed (1): app.py (+20 -8)
app.py CHANGED
@@ -9,12 +9,14 @@ import pickle
 import requests
 import torch
 
-is_gpu = False
-device = CUDA(0) if is_gpu else "cpu"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# # Load the pre-trained model and processor
+# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
-# Load the pre-trained model and processor
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load the Unsplash dataset
 dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
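For reference, this hunk swaps the Hugging Face transformers wrappers for the original OpenAI clip package. A minimal standalone sketch of the new loading path, assuming clip is installed from github.com/openai/CLIP:

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

device = "cuda" if torch.cuda.is_available() else "cpu"

# clip.load returns the model (a torch.nn.Module moved to `device`) and a
# torchvision preprocessing pipeline (Resize, CenterCrop, ToTensor, Normalize)
# that maps a single PIL image to a normalized 3x224x224 float tensor.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
print(type(model).__name__)  # CLIP
print(preprocess)            # the Compose(...) preprocessing pipeline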
@@ -28,6 +30,17 @@ def predict(image, labels):
     probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
     return {k: float(v) for k, v in zip(labels, probs[0])}
 
+
+def predict2(image, labels):
+    image = orig_clip_processor(image).unsqueeze(0).to(device)
+    text = clip.tokenize(labels).to(device)
+    with torch.no_grad():
+        image_features = orig_clip_model.encode_image(image)
+        text_features = orig_clip_model.encode_text(text)
+        logits_per_image, logits_per_text = orig_clip_model(image, text)
+        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+    return {k: float(v) for k, v in zip(labels, probs[0])}
+
 def rand_image():
     n = dataset.num_rows
     r = random.randrange(0,n)
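A quick usage sketch for the new predict2, run in app.py's context (it relies on the module-level orig_clip_model, orig_clip_processor, and device); the image file and label list below are hypothetical:

from PIL import Image

labels = ["a photo of a dog", "a photo of a cat"]  # hypothetical labels
image = Image.open("example.jpg")                  # hypothetical local file

scores = predict2(image, labels)   # -> {label: probability}
print(max(scores, key=scores.get)) # most likely label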
@@ -48,7 +61,6 @@ emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
 with open(emb_filename, 'rb') as emb:
     id2url, img_names, img_emb = pickle.load(emb)
 
-orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
 def search(search_query):
 
@@ -124,8 +136,8 @@ with gr.Blocks() as demo:
         label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
         label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
         get_btn.click(fn=rand_image, outputs=im)
-        im.change(predict, inputs=[im, labels], outputs=cf)
-        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
+        im.change(predict2, inputs=[im, labels], outputs=cf)
+        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)
 
     with gr.Tab("Image Captioning"):
         with gr.Row():
 
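The rewired handlers follow the usual Gradio Blocks event pattern. A self-contained sketch of that pattern with a stub in place of predict2 (component names mirror the app; the stub and its default labels are illustrative only):

import gradio as gr

def classify_stub(image, labels):
    # Stand-in for predict2: uniform probability over the given labels.
    labels = labels or ["unlabeled"]
    return {k: 1.0 / len(labels) for k in labels}

with gr.Blocks() as demo:
    im = gr.Image(type="pil")
    labels = gr.State(["cat", "dog", "landscape"])  # hypothetical defaults
    cf = gr.Label(label="Classification")
    reclass_btn = gr.Button("Classify again")
    # Re-run the classifier whenever the image changes or the button is clicked.
    im.change(classify_stub, inputs=[im, labels], outputs=cf)
    reclass_btn.click(classify_stub, inputs=[im, labels], outputs=cf)

# demo.launch()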