taesiri committed
Commit 41b72e9
Parent: feccca2

initial-commit

Files changed (3)
  1. SimSearch.py +46 -0
  2. app.py +153 -0
  3. requirement.txt +8 -0
SimSearch.py ADDED
@@ -0,0 +1,46 @@
+ import faiss
+ import numpy as np
+
+
+ class FaissNeighbors:
+     """Exact k-NN over raw vectors using squared-L2 distance."""
+
+     def __init__(self):
+         self.index = None
+         self.y = None
+
+     def fit(self, X, y):
+         self.index = faiss.IndexFlatL2(X.shape[1])
+         self.index.add(X.astype(np.float32))
+         self.y = y
+
+     def get_distances_and_indices(self, X, top_K=1000):
+         distances, indices = self.index.search(X.astype(np.float32), k=top_K)
+         return np.copy(distances), np.copy(indices), np.copy(self.y[indices])
+
+     def get_nearest_labels(self, X, top_K=1000):
+         distances, indices = self.index.search(X.astype(np.float32), k=top_K)
+         return np.copy(self.y[indices])
+
+
+ class FaissCosineNeighbors:
+     """Exact k-NN under cosine similarity: inner product over L2-normalized vectors."""
+
+     def __init__(self):
+         self.cindex = None
+         self.y = None
+
+     def fit(self, X, y):
+         self.cindex = faiss.index_factory(X.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
+         X = np.copy(X).astype(np.float32)  # faiss.normalize_L2 requires float32
+         faiss.normalize_L2(X)
+         self.cindex.add(X)
+         self.y = y
+
+     def get_distances_and_indices(self, Q, topK=1000):
+         Q = np.copy(Q).astype(np.float32)  # cast before normalizing, as in fit()
+         faiss.normalize_L2(Q)
+         distances, indices = self.cindex.search(Q, k=topK)
+         return np.copy(distances), np.copy(indices), np.copy(self.y[indices])
+
+     def get_nearest_labels(self, Q, topK=1000):
+         Q = np.copy(Q).astype(np.float32)
+         faiss.normalize_L2(Q)
+         distances, indices = self.cindex.search(Q, k=topK)
+         return np.copy(self.y[indices])
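For reference, FaissCosineNeighbors is the index that app.py builds over the per-video CLIP embeddings. A minimal sketch of how it behaves on random data (the shapes and names below are illustrative, not part of the commit):

    import numpy as np
    from SimSearch import FaissCosineNeighbors

    X = np.random.rand(1000, 512).astype(np.float32)  # 1000 fake 512-d embeddings
    y = np.arange(1000)                               # one label per vector

    searcher = FaissCosineNeighbors()
    searcher.fit(X, y)

    labels = searcher.get_nearest_labels(X[:2], topK=5)
    print(labels.shape)  # (2, 5): top-5 labels for each of the 2 query rows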
app.py ADDED
@@ -0,0 +1,153 @@
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ import clip
+ from glob import glob
+ import gradio as gr
+ import os
+ import torchvision
+ import pickle
+ from collections import Counter
+
+ from SimSearch import FaissCosineNeighbors
+
+ # HELPERS
+ to_np = lambda x: x.data.to('cpu').numpy()
+
+ # DOWNLOAD THE DATASET AND FILES
+
+ torchvision.datasets.utils.download_file_from_google_drive('1kB1vNdVaNS1OGZ3K8BspBUKkPACCsnrG', '.', 'GTAV-Videos.zip')
+ torchvision.datasets.utils.download_file_from_google_drive('1pgvIBTs_6h23wIU28EdqO5y2T1wUfOak', '.', 'GTAV-embedding-vit32.zip')
+
+ # EXTRACT
+ torchvision.datasets.utils.extract_archive(from_path='GTAV-embedding-vit32.zip', to_path='Embeddings/VIT32/', remove_finished=False)
+ torchvision.datasets.utils.extract_archive(from_path='GTAV-Videos.zip', to_path='Videos/', remove_finished=False)
+
+ # List the available CLIP models (the model itself is loaded below)
+ clip.available_models()
+
+ # Searcher
+
+ class GamePhysicsSearcher:
+     def __init__(self, CLIP_MODEL, GAME_NAME, EMBEDDING_PATH='./Embeddings/VIT32/'):
+         self.CLIP_MODEL = CLIP_MODEL
+         self.GAME_NAME = GAME_NAME
+         self.simsearcher = FaissCosineNeighbors()
+
+         self.all_embeddings = glob(f'{EMBEDDING_PATH}{self.GAME_NAME}/*.npy')
+
+         self.filenames = [os.path.basename(x) for x in self.all_embeddings]
+         self.file_to_class_id = {x: i for i, x in enumerate(self.filenames)}
+         self.class_id_to_file = {i: x for i, x in enumerate(self.filenames)}
+         self.build_index()
+
+     def read_features(self, file_path):
+         with open(file_path, 'rb') as f:
+             video_features = pickle.load(f)
+         return video_features
+
+     def read_all_features(self):
+         features = {}              # per-file features, kept for debugging
+         filenames_extended = []    # per-frame filenames, kept for debugging
+
+         X_train = []
+         y_train = []
+
+         for i, vfile in enumerate(tqdm(self.all_embeddings)):
+             vfeatures = to_np(self.read_features(vfile))
+             features[vfile.split('/')[-1]] = vfeatures
+             X_train.extend(vfeatures)
+             y_train.extend([i] * vfeatures.shape[0])  # one class id per video
+             filenames_extended.extend(vfeatures.shape[0] * [vfile.split('/')[-1]])
+
+         X_train = np.asarray(X_train)
+         y_train = np.asarray(y_train)
+
+         return X_train, y_train
+
+     def build_index(self):
+         X_train, y_train = self.read_all_features()
+         self.simsearcher.fit(X_train, y_train)
+
+     def text_to_vector(self, query):
+         text_tokens = clip.tokenize(query).cuda()
+         with torch.no_grad():
+             text_features = self.CLIP_MODEL.encode_text(text_tokens).float()
+             text_features /= text_features.norm(dim=-1, keepdim=True)
+         return to_np(text_features)
+
+     # Order-preserving deduplication. Source: https://stackoverflow.com/a/480227
+     def f7(self, seq):
+         seen = set()
+         seen_add = seen.add  # bound-method lookup hoisted for performance, don't remove
+         return [x for x in seq if not (x in seen or seen_add(x))]
+
+     def search_top_k(self, q, k=5, pool_size=1000, search_mod='Majority'):
+         q = self.text_to_vector(q)
+         nearest_data_points = self.simsearcher.get_nearest_labels(q, pool_size)
+
+         if search_mod == 'Majority':
+             # rank videos by how many of their frames land in the pool
+             topKs = [x[0] for x in Counter(nearest_data_points[0]).most_common(k)]
+         elif search_mod == 'Top-K':
+             # rank videos by their closest frame, deduplicated in order
+             topKs = list(self.f7(nearest_data_points[0]))[:k]
+
+         video_filename = [f'./Videos/{self.GAME_NAME}/' + self.class_id_to_file[x].replace('npy', 'mp4') for x in topKs]
+
+         return video_filename
+
+
+ ################ SEARCH CORE ################
+ # CREATE CLIP MODEL
+ vit_model, vit_preprocess = clip.load("ViT-B/32")
+ vit_model.cuda().eval()
+
+ saved_searchers = {}
+
+
+ def gradio_search(query, game_name, selected_model, aggregator, pool_size, k=5):
+     # k=5 matches the five video outputs declared in the interface below
+     # print(query, game_name, selected_model, aggregator, pool_size)
+     if f'{game_name}_{selected_model}' in saved_searchers:
+         searcher = saved_searchers[f'{game_name}_{selected_model}']
+     else:
+         if selected_model == 'ViT-B/32':
+             model = vit_model
+             searcher = GamePhysicsSearcher(CLIP_MODEL=model, GAME_NAME=game_name)
+         else:
+             raise ValueError(f'Unknown model: {selected_model}')
+
+         saved_searchers[f'{game_name}_{selected_model}'] = searcher
+
+     results = []
+     relevant_videos = searcher.search_top_k(query, k=k, pool_size=pool_size, search_mod=aggregator)
+     params = ', '.join(map(str, [query, game_name, selected_model, aggregator, pool_size]))
+     results.append(params)
+     results.extend(relevant_videos)
+     print(results)
+     return results
+
+
+ list_of_games = ['Grand Theft Auto V']
+
+ # GRADIO APP
+ iface = gr.Interface(fn=gradio_search,
+                      inputs=[gr.inputs.Textbox(lines=1, placeholder='Search Query', default="A man in the air", label=None),
+                              gr.inputs.Radio(list_of_games, label="Game To Search"),
+                              gr.inputs.Radio(['ViT-B/32'], label="MODEL"),
+                              gr.inputs.Radio(['Majority', 'Top-K'], label="Aggregator"),
+                              gr.inputs.Slider(300, 2000, label="Pool Size"),
+                              ],
+                      outputs=[gr.outputs.Textbox(type="auto", label='Search Params'),
+                               gr.outputs.Video(type='mp4', label='Result 1'),
+                               gr.outputs.Video(type='mp4', label='Result 2'),
+                               gr.outputs.Video(type='mp4', label='Result 3'),
+                               gr.outputs.Video(type='mp4', label='Result 4'),
+                               gr.outputs.Video(type='mp4', label='Result 5')],
+                      # examples=[],
+                      title='CLIP Meets Game Physics Demo')
+
+ # server options belong to launch(), not the Interface constructor
+ iface.launch(server_name="0.0.0.0", server_port=7878)
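For a quick smoke test of the search pipeline without the Gradio UI, something like the following should work once the archives above have been downloaded and extracted (the query string is only an example; a CUDA GPU is assumed, as in the rest of app.py):

    searcher = GamePhysicsSearcher(CLIP_MODEL=vit_model, GAME_NAME='Grand Theft Auto V')
    print(searcher.search_top_k('a car stuck in a wall', k=5, pool_size=1000, search_mod='Majority'))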
requirement.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ numpy
+ tqdm
+ Pillow
+ scikit-image
+ gdown
+ torchvision
+ faiss-gpu  # imported by SimSearch.py; use faiss-cpu on CPU-only hosts
+ git+https://github.com/openai/CLIP.git