|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
import clip |
|
from glob import glob |
|
import gradio as gr |
|
import os |
|
import torchvision |
|
import pickle |
|
from collections import Counter |
|
|
|
from SimSearch import FaissCosineNeighbors |
|
|
|
|
|
def to_np(x):
    """Return tensor ``x`` as a NumPy array on the CPU.

    The tensor is detached from the autograd graph first, so tensors that
    require grad are accepted as well.
    """
    return x.detach().cpu().numpy()
|
|
|
|
|
|
|
def _download_once(file_id, filename):
    """Fetch ``filename`` from Google Drive into the working directory unless it already exists."""
    if not os.path.isfile(filename):
        torchvision.datasets.utils.download_file_from_google_drive(file_id, '.', filename)


def _extract_once(archive, target_dir):
    """Unpack ``archive`` into ``target_dir`` unless that directory already has content.

    Re-extracting the large archives on every start-up is wasted work;
    delete ``target_dir`` to force a fresh extraction.
    """
    if os.path.isdir(target_dir) and os.listdir(target_dir):
        return
    torchvision.datasets.utils.extract_archive(from_path=archive, to_path=target_dir, remove_finished=False)


_download_once('1kB1vNdVaNS1OGZ3K8BspBUKkPACCsnrG', 'GTAV-Videos.zip')
_download_once('1pgvIBTs_6h23wIU28EdqO5y2T1wUfOak', 'GTAV-embedding-vit32.zip')

_extract_once('GTAV-embedding-vit32.zip', 'Embeddings/VIT32/')
_extract_once('GTAV-Videos.zip', 'Videos/')
|
|
|
|
|
|
|
class GamePhysicsSearcher:
    """Text-to-video search over pre-computed CLIP embeddings for one game.

    Every ``*.npy`` file under ``<EMBEDDING_PATH>/<GAME_NAME>/`` holds the
    embeddings of one video. All of their rows are indexed in a
    ``FaissCosineNeighbors`` searcher, each labelled with the integer id of
    the file it came from, so a nearest-neighbour query yields video ids.
    """

    def __init__(self, CLIP_MODEL, GAME_NAME, EMBEDDING_PATH='./Embeddings/VIT32/'):
        """Collect the game's embedding files and build the search index.

        Args:
            CLIP_MODEL: CLIP model used to embed text queries (must provide
                ``encode_text`` and ``parameters``).
            GAME_NAME: sub-directory name under both ``EMBEDDING_PATH`` and
                ``./Videos/``.
            EMBEDDING_PATH: root directory holding one folder per game.

        Raises:
            FileNotFoundError: if no ``*.npy`` file exists for ``GAME_NAME``.
        """
        self.CLIP_MODEL = CLIP_MODEL
        self.GAME_NAME = GAME_NAME
        self.simsearcher = FaissCosineNeighbors()

        # Sorted so the file -> class-id assignment is stable across runs
        # (glob order depends on the filesystem). os.path.join also copes
        # with an EMBEDDING_PATH that lacks a trailing slash.
        self.all_embeddings = sorted(glob(os.path.join(EMBEDDING_PATH, self.GAME_NAME, '*.npy')))

        self.filenames = [os.path.basename(path) for path in self.all_embeddings]
        self.file_to_class_id = {name: i for i, name in enumerate(self.filenames)}
        self.class_id_to_file = dict(enumerate(self.filenames))

        self.build_index()

    def read_features(self, file_path):
        """Load and return the pickled object stored in ``file_path``.

        NOTE(review): despite the ``.npy`` extension these files are read with
        ``pickle`` and later passed through ``to_np``, so they presumably hold
        torch tensors. Unpickling can execute arbitrary code -- only load
        embedding files from a trusted source.
        """
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    def read_all_features(self):
        """Stack the embeddings of every video file into one training matrix.

        Returns:
            X_train: all embedding rows, concatenated along axis 0.
            y_train: for each row of ``X_train``, the class id (index into
                ``self.all_embeddings``) of the file it came from.

        Raises:
            FileNotFoundError: if no embedding files were found.
        """
        if not self.all_embeddings:
            raise FileNotFoundError(
                f'No embedding files (*.npy) found for game {self.GAME_NAME!r}')

        chunks = []
        counts = []
        for vfile in tqdm(self.all_embeddings):
            vfeatures = to_np(self.read_features(vfile))
            chunks.append(vfeatures)
            counts.append(vfeatures.shape[0])

        # One concatenation instead of growing a Python list row by row.
        X_train = np.concatenate(chunks, axis=0)
        # File i contributes counts[i] rows, all labelled i.
        y_train = np.repeat(np.arange(len(chunks)), counts)
        return X_train, y_train

    def build_index(self):
        """Fit the similarity searcher on the rows of all embedding files."""
        X_train, y_train = self.read_all_features()
        self.simsearcher.fit(X_train, y_train)

    def text_to_vector(self, query):
        """Embed a text query with CLIP and L2-normalise it.

        Returns:
            A NumPy array with one unit-norm row per tokenized input text.
        """
        # Put the tokens on whatever device the model lives on instead of
        # assuming CUDA, so a CPU-only model works too.
        device = next(self.CLIP_MODEL.parameters()).device
        text_tokens = clip.tokenize(query).to(device)
        with torch.no_grad():
            text_features = self.CLIP_MODEL.encode_text(text_tokens).float()
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return to_np(text_features)

    def f7(self, seq):
        """Return the items of ``seq`` without duplicates, keeping first-occurrence order."""
        # dict keys keep insertion order, giving an order-preserving de-dup
        # of the (hashable) items.
        return list(dict.fromkeys(seq))

    def search_top_k(self, q, k=5, pool_size=1000, search_mod='Majority'):
        """Return paths of the ``k`` videos that best match the text query ``q``.

        The ``pool_size`` nearest embedding rows are retrieved first. Their
        video labels are then aggregated according to ``search_mod``:

        * ``'Majority'`` -- the ``k`` videos owning the most rows in the pool.
        * ``'Top-K'`` -- the first ``k`` distinct videos in the order returned
          by the searcher (presumably nearest first).

        Raises:
            ValueError: if ``search_mod`` is not one of the modes above.
        """
        query_vector = self.text_to_vector(q)
        # Cast defensively: UI widgets (e.g. sliders) may pass a float count.
        nearest_data_points = self.simsearcher.get_nearest_labels(query_vector, int(pool_size))
        # Labels for the first (and only) query vector.
        labels = nearest_data_points[0]

        if search_mod == 'Majority':
            top_ids = [label for label, _ in Counter(labels).most_common(k)]
        elif search_mod == 'Top-K':
            top_ids = self.f7(labels)[:k]
        else:
            raise ValueError(
                f"Unknown search_mod {search_mod!r}; expected 'Majority' or 'Top-K'")

        # Swap only the extension: str.replace('npy', 'mp4') would also
        # rewrite an 'npy' occurring elsewhere in the file name.
        return [
            f'./Videos/{self.GAME_NAME}/' + os.path.splitext(self.class_id_to_file[label])[0] + '.mp4'
            for label in top_ids
        ]
|
|
|
|
|
|
|
|
|
|
|
# Load CLIP ViT-B/32 on the GPU when one is available, otherwise on the CPU
# (an unconditional .cuda() call crashes on CPU-only machines).
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
vit_model, vit_preprocess = clip.load("ViT-B/32", device=_device)
vit_model.eval()

# Cache of built searchers keyed by "<game_name>_<model_name>", so each
# game's index is built at most once per process.
saved_searchers = {}
|
def gradio_search(query, game_name, selected_model, aggregator, pool_size, k=6):
    """Gradio callback: run a text query and return the parameters plus video paths.

    Searchers are built lazily and cached in ``saved_searchers`` per
    (game, model) pair, because building one loads and indexes every
    embedding file of the game.

    Returns:
        A list whose first item is a comma-separated echo of the search
        parameters, followed by exactly ``k`` entries: video paths, padded
        with ``None`` when fewer than ``k`` videos were found.

    Raises:
        ValueError: if no game is selected or ``selected_model`` is unsupported.
    """
    if not game_name:
        raise ValueError('Please select a game to search.')

    cache_key = f'{game_name}_{selected_model}'
    searcher = saved_searchers.get(cache_key)
    if searcher is None:
        if selected_model != 'ViT-B/32':
            # Explicit exception: a bare `raise` outside an except block
            # would fail with an unrelated RuntimeError.
            raise ValueError(f'Unsupported model: {selected_model!r}')
        searcher = GamePhysicsSearcher(CLIP_MODEL=vit_model, GAME_NAME=game_name)
        saved_searchers[cache_key] = searcher

    relevant_videos = list(searcher.search_top_k(query, k=k, pool_size=pool_size, search_mod=aggregator))
    # Pad so every video output slot gets a value even when fewer than k
    # distinct videos were found.
    # NOTE(review): the Interface declares five Video outputs, so with the
    # default k=6 the last entry is not displayed -- consider k=5.
    relevant_videos += [None] * (k - len(relevant_videos))

    params = ', '.join(map(str, [query, game_name, selected_model, aggregator, pool_size]))
    results = [params, *relevant_videos]
    print(results)
    return results
|
|
|
list_of_games = ['Grand Theft Auto V']

# Every Radio input gets a default so the callback never receives None for
# an option the user did not click.
iface = gr.Interface(
    fn=gradio_search,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder='Search Query', default="A man in the air", label=None),
        gr.inputs.Radio(list_of_games, default=list_of_games[0], label="Game To Search"),
        gr.inputs.Radio(['ViT-B/32'], default='ViT-B/32', label="MODEL"),
        gr.inputs.Radio(['Majority', 'Top-K'], default='Majority', label="Aggregator"),
        gr.inputs.Slider(300, 2000, label="Pool Size"),
    ],
    outputs=[
        gr.outputs.Textbox(type="auto", label='Search Params'),
        # Five result slots labelled "Result 1" .. "Result 5".
        *[gr.outputs.Video(type='mp4', label=f'Result {i}') for i in range(1, 6)],
    ],
    server_port=7878,
    server_name="0.0.0.0",
    title='CLIP Meets Game Physics Demo',
)

iface.launch()