File size: 5,548 Bytes
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import numpy as np
import torch
from tqdm import tqdm
import clip
from glob import glob
import gradio as gr
import os
import torchvision
import pickle
from collections import Counter

from SimSearch import FaissCosineNeighbors

# HELPERS
def to_np(x):
  """Return tensor ``x`` as a NumPy array on the CPU, detached from autograd."""
  return x.detach().cpu().numpy()

# DOWNLOAD THE DATASET AND FILES
# Both archives are fetched from Google Drive into the current directory.
torchvision.datasets.utils.download_file_from_google_drive('1kB1vNdVaNS1OGZ3K8BspBUKkPACCsnrG', '.', 'GTAV-Videos.zip')
torchvision.datasets.utils.download_file_from_google_drive('1pgvIBTs_6h23wIU28EdqO5y2T1wUfOak', '.', 'GTAV-embedding-vit32.zip')

# EXTRACT
# Embeddings land in Embeddings/VIT32/ (the default EMBEDDING_PATH of the
# searcher) and videos in Videos/ (where search results are looked up).
# remove_finished=False keeps the .zip files on disk after extraction.
torchvision.datasets.utils.extract_archive(from_path='GTAV-embedding-vit32.zip', to_path='Embeddings/VIT32/', remove_finished=False)
torchvision.datasets.utils.extract_archive(from_path='GTAV-Videos.zip', to_path='Videos/', remove_finished=False)

# # Searcher

class GamePhysicsSearcher:
  """Text-to-video search over pre-computed CLIP embeddings for one game.

  Every ``*.npy`` file under ``EMBEDDING_PATH/GAME_NAME/`` holds the features of
  one video and is assigned one class id. All feature rows are indexed in a
  ``FaissCosineNeighbors`` searcher, labelled with their file's class id, and a
  text query is answered by aggregating the labels of its nearest rows.

  Args:
    CLIP_MODEL: loaded CLIP model used to embed text queries.
    GAME_NAME: game sub-directory name, under both the embedding root and
      ``./Videos/``.
    EMBEDDING_PATH: root directory containing one sub-directory per game.
  """

  def __init__(self, CLIP_MODEL, GAME_NAME, EMBEDDING_PATH='./Embeddings/VIT32/'):
    self.CLIP_MODEL = CLIP_MODEL
    self.GAME_NAME = GAME_NAME
    self.simsearcher = FaissCosineNeighbors()

    # os.path.join works with or without a trailing separator on
    # EMBEDDING_PATH; sorting makes class ids independent of glob order.
    self.all_embeddings = sorted(glob(os.path.join(EMBEDDING_PATH, self.GAME_NAME, '*.npy')))

    self.filenames = [os.path.basename(x) for x in self.all_embeddings]
    self.file_to_class_id = {x: i for i, x in enumerate(self.filenames)}
    self.class_id_to_file = dict(enumerate(self.filenames))
    self.build_index()

  def read_features(self, file_path):
    """Unpickle and return the feature object stored at ``file_path``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. The embedding
    archive is downloaded from Google Drive by this script, so only load
    archives from a trusted source.
    """
    with open(file_path, 'rb') as f:
      return pickle.load(f)

  def read_all_features(self):
    """Load every embedding file and stack the rows into one training matrix.

    Returns:
      (X_train, y_train): ``X_train`` concatenates the rows of every file in
      ``self.all_embeddings`` order; ``y_train[j]`` is the class id (index in
      ``self.all_embeddings``) of the file that row ``j`` came from.
    """
    blocks = []
    labels = []
    for class_id, vfile in enumerate(tqdm(self.all_embeddings)):
      vfeatures = to_np(self.read_features(vfile))
      blocks.append(vfeatures)
      # One label per row of this file.
      labels.extend([class_id] * vfeatures.shape[0])

    if not blocks:
      # Nothing found on disk: return empty arrays rather than failing here.
      return np.asarray([]), np.asarray(labels)
    return np.concatenate(blocks, axis=0), np.asarray(labels)

  def build_index(self):
    """Fit the nearest-neighbour searcher on all feature rows and their labels."""
    X_train, y_train = self.read_all_features()
    self.simsearcher.fit(X_train, y_train)

  def text_to_vector(self, query):
    """Embed ``query`` with CLIP and return the L2-normalised features as numpy."""
    # Send tokens to wherever the model lives instead of assuming CUDA.
    device = next(self.CLIP_MODEL.parameters()).device
    text_tokens = clip.tokenize(query).to(device)
    with torch.no_grad():
      text_features = self.CLIP_MODEL.encode_text(text_tokens).float()
      text_features /= text_features.norm(dim=-1, keepdim=True)
    return to_np(text_features)

  def f7(self, seq):
    """Return the items of ``seq`` without duplicates, keeping first-seen order."""
    # dicts preserve insertion order, so their keys form an ordered dedupe.
    return list(dict.fromkeys(seq))

  def search_top_k(self, q, k=5, pool_size=1000, search_mod='Majority'):
    """Return paths of the ``k`` videos that best match the text query ``q``.

    The ``pool_size`` nearest feature rows are retrieved and aggregated:
      * ``'Majority'``: videos owning the most rows in the pool, most first.
      * ``'Top-K'``: the first ``k`` distinct videos in the order the
        searcher returned them (presumably nearest first).

    Raises:
      ValueError: if ``search_mod`` is neither ``'Majority'`` nor ``'Top-K'``.
    """
    q = self.text_to_vector(q)
    nearest_labels = self.simsearcher.get_nearest_labels(q, pool_size)[0]

    if search_mod == 'Majority':
      top_ids = [label for label, _ in Counter(nearest_labels).most_common(k)]
    elif search_mod == 'Top-K':
      top_ids = self.f7(nearest_labels)[:k]
    else:
      raise ValueError(f"Unknown search_mod {search_mod!r}; expected 'Majority' or 'Top-K'")

    # Swap only the extension: str.replace('npy', 'mp4') would also rewrite
    # any 'npy' occurring inside the file name itself.
    return [
      f'./Videos/{self.GAME_NAME}/' + os.path.splitext(self.class_id_to_file[x])[0] + '.mp4'
      for x in top_ids
    ]



################ SEARCH CORE ################
# CREATE CLIP MODEL
# The ViT-B/32 weights are loaded once at start-up and shared by every searcher.
vit_model, vit_preprocess = clip.load("ViT-B/32")
# Module.cuda() and Module.eval() act in place and return the same module.
vit_model = vit_model.cuda().eval()

# Searchers built so far, keyed by "<game_name>_<selected_model>", so each
# index is only built once per (game, model) pair.
saved_searchers = {}
def gradio_search(query, game_name, selected_model, aggregator, pool_size, k=6):
  """Gradio callback: run a text query against one game's video index.

  Searchers are built lazily and cached in ``saved_searchers`` under the key
  ``"<game_name>_<selected_model>"``, since building one indexes every
  embedding file of the game.

  Args:
    query: free-text search query.
    game_name: game sub-directory to search.
    selected_model: CLIP model name; only 'ViT-B/32' is supported.
    aggregator: 'Majority' or 'Top-K', forwarded as ``search_mod``.
    pool_size: number of nearest feature rows to aggregate over.
    k: number of videos to request.
      NOTE(review): defaults to 6 while the Gradio interface defines only
      5 video outputs — confirm intent.

  Returns:
    A list whose first item is a comma-separated summary of the parameters,
    followed by the paths of the result videos.

  Raises:
    ValueError: if ``selected_model`` is not supported.
  """
  cache_key = f'{game_name}_{selected_model}'
  searcher = saved_searchers.get(cache_key)
  if searcher is None:
    if selected_model != 'ViT-B/32':
      raise ValueError(f'Unsupported model: {selected_model!r}')
    searcher = GamePhysicsSearcher(CLIP_MODEL=vit_model, GAME_NAME=game_name)
    saved_searchers[cache_key] = searcher

  # The Gradio slider may deliver a float; the pool size is a neighbour count.
  relevant_videos = searcher.search_top_k(query, k=k, pool_size=int(pool_size), search_mod=aggregator)
  params = ', '.join(map(str, [query, game_name, selected_model, aggregator, pool_size]))
  results = [params, *relevant_videos]
  print(results)
  return results

list_of_games = ['Grand Theft Auto V']


# GRADIO APP
# Inputs map positionally onto gradio_search(query, game_name, selected_model,
# aggregator, pool_size); k keeps its default.
search_inputs = [
  gr.inputs.Textbox(lines=1, placeholder='Search Query', default="A man in the air", label=None),
  gr.inputs.Radio(list_of_games, label="Game To Search"),
  gr.inputs.Radio(['ViT-B/32'], label="MODEL"),
  gr.inputs.Radio(['Majority', 'Top-K'], label="Aggregator"),
  gr.inputs.Slider(300, 2000, label="Pool Size"),
]

# First output echoes the search parameters; each following one shows a video.
search_outputs = [gr.outputs.Textbox(type="auto", label='Search Params')]
search_outputs += [gr.outputs.Video(type='mp4', label=f'Result {rank}') for rank in range(1, 6)]

iface = gr.Interface(
  fn=gradio_search,
  inputs=search_inputs,
  outputs=search_outputs,
  server_port=7878,
  server_name="0.0.0.0",
  title='CLIP Meets Game Physics Demo',
)
iface.launch()