Diangle committed
Commit 256a58e · 1 Parent(s): b1db48b

Update app.py

Files changed (1)
  1. app.py +42 -34
app.py CHANGED
@@ -1,4 +1,4 @@
-import gradio
+import gradio as gr
 import os
 import numpy as np
 import pandas as pd
@@ -8,12 +8,17 @@ import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection
 
 
+TITLE="""<h1 style="font-size: 42px;" align="center">Video Retrieval</h1>"""
+
+DESCRIPTION="""This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
+IMAGE='<img src="./Searchium.png"/>'
+
 DATA_PATH = './data'
 
 ft_visual_features_file = DATA_PATH + '/dataset_v1_visual_features_database.npy'
-binary_visual_features_file = DATA_PATH + '/dataset_v1_visual_features_database_packed.npy'
+
+#load database features:
 ft_visual_features_database = np.load(ft_visual_features_file)
-binary_visual_features = np.load(binary_visual_features_file)
 
 database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
 database_df = pd.read_csv(database_csv_path)
@@ -38,7 +43,7 @@ class NearestNeighbors:
     def fit(self, data, o_data=None):
         if self.metric == 'cosine':
             data = self.normalize(data)
-            self.index = faiss.IndexFlatIP(data.shape[1])
+            self.index = faiss.IndexFlatIP(data.shape[1])
         elif self.metric == 'binary':
             self.o_data = data if o_data is None else o_data
             #assuming data already packed
@@ -47,44 +52,37 @@ class NearestNeighbors:
 
     def kneighbors(self, q_data):
         if self.metric == 'cosine':
-            print('cosine search')
-            q_data = self.normalize(q_data)
+            q_data = self.normalize(q_data)
             sim, idx = self.index.search(q_data, self.n_neighbors)
         else:
             if self.metric == 'binary':
-                print('binary search')
+                print('binary search: ')
                 bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
                 print(bq_data.shape, self.index.d)
                 sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
 
                 if self.rerank_from > self.n_neighbors:
-                    sim_float = np.zeros([len(q_data), self.rerank_from], dtype=float)
-                    for i, q in enumerate(q_data):
-                        candidates = np.take_along_axis(self.o_data, idx[i:i+1,:].T, axis=0)
-                        sim_float[i,:] = q @ candidates.T
-                        sort_idx = np.argsort(sim_float[i,:])[::-1]
-                        sim_float[i,:] = sim_float[i,:][sort_idx]
-                        idx[i,:] = idx[i,:][sort_idx]
-                    sim = sim_float[:,:self.n_neighbors]
-                    idx = idx[:,:self.n_neighbors]
+                    rerank_data = self.o_data[idx[0]]
+                    rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
+                    rerank_search.fit(rerank_data)
+                    sim, re_idxs = rerank_search.kneighbors(q_data)
+                    idx = [idx[0][re_idxs[0]]]
 
         return sim, idx
 
-
+model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
+tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
+
 def search(search_sentence):
-    my_model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
-    tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
-
-
     inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)
 
-    outputs = my_model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
-
-    text_projection = my_model.state_dict()['text_projection.weight']
+    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
+    # Customized projection layer
+    text_projection = model.state_dict()['text_projection.weight']
    text_embeds = outputs[1] @ text_projection
     final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]
 
-
+    # Normalization
     final_output = final_output / final_output.norm(dim=-1, keepdim=True)
     final_output = final_output.cpu().detach().numpy()
     sequence_output = final_output / np.sum(final_output**2, axis=1, keepdims=True)
@@ -94,12 +92,22 @@ def search(search_sentence):
     sims, idxs = nn_search.kneighbors(sequence_output)
     return database_df.iloc[idxs[0]]['contentUrl'].to_list()
 
-
-gradio.close_all()
-
-interface = gradio.Interface(search,
-                             inputs=[gradio.Textbox()],
-                             outputs=[gradio.Video(format='mp4') for _ in range(5)],
-                             title = 'Video Search Demo',
-                             description = 'Type some text to search by content within a video database!',
-                             ).launch()
+
+with gr.Blocks() as demo:
+    gr.HTML(TITLE)
+    gr.Markdown(DESCRIPTION)
+    gr.HTML(IMAGE)
+    gr.Markdown("Retrieval of top 5 videos relevant to the input sentence: ")
+    with gr.Row():
+        with gr.Column():
+            inp = gr.Textbox(placeholder="Write a sentence.")
+            btn = gr.Button(value="Retrieve")
+            ex = [["a woman waving to the camera"], ["a basketball player performing a slam dunk"], ["how to bake a chocolate cake"], ["birds fly in the sky"]]
+            gr.Examples(examples=ex,
+                        inputs=[inp],
+                        )
+        with gr.Column():
+            out = [gr.Video(format='mp4') for _ in range(5)]
+    btn.click(search, inputs=inp, outputs=out)
+
+demo.launch(debug=True, share=True)
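The rewritten `kneighbors` replaces the per-query rerank loop with a second `NearestNeighbors` pass: a fast Hamming search over bit-packed features returns `rerank_from` candidates, which are then re-scored exactly with cosine similarity. A minimal standalone sketch of that two-stage pattern, assuming 512-dimensional features and illustrative values for `rerank_from` and `n_neighbors`:

```python
import numpy as np
import faiss

d = 512                                    # feature dimension (assumed)
db = np.random.randn(10000, d).astype('float32')  # stand-in for the database features
q = np.random.randn(1, d).astype('float32')       # stand-in for a query embedding

# Stage 1: coarse Hamming search over packed sign bits.
db_bits = np.packbits((db > 0).astype(bool), axis=1)
q_bits = np.packbits((q > 0).astype(bool), axis=1)
coarse = faiss.IndexBinaryFlat(d)          # dimension is given in bits
coarse.add(db_bits)
_, cand = coarse.search(q_bits, 100)       # rerank_from = 100 (illustrative)

# Stage 2: exact cosine re-rank of those candidates only.
cand_vecs = db[cand[0]]                    # fancy indexing copies the rows
cand_vecs /= np.linalg.norm(cand_vecs, axis=1, keepdims=True)
qn = q / np.linalg.norm(q, axis=1, keepdims=True)
fine = faiss.IndexFlatIP(d)                # inner product == cosine after normalization
fine.add(cand_vecs)
sim, re_idx = fine.search(qn, 5)           # n_neighbors = 5, as in the demo
final_idx = cand[0][re_idx[0]]             # map positions back to database row ids
```

Note that, like the committed code (`idx = [idx[0][re_idxs[0]]]`), this handles a single query; the old loop scored a whole batch of queries at once.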
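Hoisting `model` and `tokenizer` to module scope means the checkpoint is loaded once at startup rather than on every query. Inside `search`, the projection layer is applied by hand and the sentence embedding is read off at the end-of-text position; `input_ids.argmax(dim=-1)` finds that position because CLIP's `<|endoftext|>` token carries the highest id in the vocabulary. A sketch of the same pipeline in isolation (same model id as the app; the example sentence is arbitrary):

```python
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")

inputs = tokenizer(text="birds fly in the sky", return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    return_dict=False)

# outputs[1] holds the per-token hidden states; project them with the stored
# weight, then keep the vector at the <|endoftext|> position.
text_projection = model.state_dict()['text_projection.weight']
text_embeds = outputs[1] @ text_projection
eos_pos = inputs["input_ids"].argmax(dim=-1)
query = text_embeds[torch.arange(text_embeds.shape[0]), eos_pos]
query = query / query.norm(dim=-1, keepdim=True)   # unit-normalize for cosine search
```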
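In the new Blocks layout, `btn.click` fans the five returned `contentUrl` links out to the five `gr.Video` components. The retrieval path can also be smoke-tested without the UI; a hypothetical run, assuming app.py's globals (`model`, `tokenizer`, `nn_search`, `database_df`) are loaded and the data files are in place:

```python
# Hypothetical sanity check of the retrieval path outside the Gradio UI.
urls = search("a basketball player performing a slam dunk")
for url in urls:
    print(url)   # top-5 contentUrl entries from dataset_v1.csv
```

`demo.launch(debug=True, share=True)` asks Gradio for a temporary public share link when running locally; on Spaces the app is served directly, so the flag mainly matters for local runs.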