import os
import shutil
import zipfile
from os.path import join, isfile, basename

import cv2
import numpy as np
import gradio as gr
from gradio.components import Video, Number, File
import torch

from resnet50 import resnet18
from sampling_util import furthest_neighbours
from video_reader import video_reader

# Per-channel normalisation applied to RGB frames after scaling to [0, 1].
# Presumably the statistics of the model's training data — TODO confirm.
MEAN = np.asarray([0.3156024, 0.33569682, 0.34337464], dtype=np.float32)
STD = np.asarray([0.16568947, 0.17827448, 0.18925823], dtype=np.float32)

# Both passes over the video must sample at the same rate so that the frame
# indices chosen during encoding refer to the same frames when extracting.
TARGET_FPS = 9
# Frames are downscaled to this width for encoding only; output keeps full size.
ENCODE_WIDTH = 100

model = resnet18(
    output_dim=0, nmb_prototypes=0, eval_mode=True, hidden_mlp=0, normalize=False)
# Inference below runs on CPU tensors, so map the weights to the CPU even if
# the checkpoint was saved from a GPU.
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))


def _encode_frames(video_path):
    """Return one pooled ``model`` embedding per sampled frame of *video_path*.

    Frames are read at TARGET_FPS and ENCODE_WIDTH, scaled to [0, 1],
    normalised with MEAN/STD, encoded and global-average-pooled. The result
    has one row per frame and is empty if no frame was decoded.
    """
    vectors = []
    with torch.no_grad():
        for frame in video_reader(video_path, targetFPS=TARGET_FPS,
                                  targetWidth=ENCODE_WIDTH, to_rgb=True):
            frame = frame.astype(np.float32) / 255.
            # MEAN/STD broadcast over the last axis, i.e. frames are
            # channel-last (H, W, C).
            frame = (frame - MEAN) / STD
            # (H, W, C) -> (1, C, H, W): a batch of one for the torch model.
            batch = torch.from_numpy(
                np.transpose(frame[np.newaxis], (0, 3, 1, 2))).float()
            vectors.append(avg_pool(model(batch))[0, :, 0, 0].cpu().numpy())
    return np.asarray(vectors)


def _write_selected_frames(video_path, selected, out_dir):
    """Write frames whose index is in *selected* to *out_dir* as ``<idx>.jpg``.

    The video is re-read at full resolution (BGR, as cv2 expects) with the
    same frame rate used for encoding. Returns the written paths in frame
    order. Raises OSError if OpenCV fails to write a frame.
    """
    written = []
    for idx, frame in enumerate(video_reader(video_path, targetFPS=TARGET_FPS,
                                             targetWidth=None, to_rgb=False)):
        # Membership test (not array indexing) so an extra decoded frame in
        # this pass cannot raise IndexError.
        if idx not in selected:
            continue
        out_path = join(out_dir, str(idx) + ".jpg")
        if not cv2.imwrite(out_path, frame):
            raise OSError(f"failed to write frame {idx} to {out_path}")
        written.append(out_path)
    return written


def _zip_files(paths, zip_path):
    """Store each file in *paths* (by base name) in a deflated zip at *zip_path*."""
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for path in paths:
            archive.write(path, basename(path))


def predict(input_file, downsample_size):
    """Select visually distinct frames from a video and return them zipped.

    Args:
        input_file: path of the uploaded video file.
        downsample_size: how many frames to keep; converted to ``int`` and
            passed to ``furthest_neighbours``.

    Returns:
        Path of a zip archive (inside ``./selected_images``) containing the
        selected frames as ``<frame index>.jpg``.

    Raises:
        ValueError: if no video was given or no frame could be decoded.
    """
    if not input_file:
        raise ValueError("No video file was provided.")
    downsample_size = int(downsample_size)

    # Start from an empty output directory so files from a previous request
    # do not end up in this request's archive.
    selected_directory = os.path.join(os.getcwd(), "selected_images")
    if os.path.isdir(selected_directory):
        shutil.rmtree(selected_directory)
    os.mkdir(selected_directory)

    # Name the archive after the video's file name without its extension.
    video_stem = os.path.splitext(basename(input_file))[0]
    zip_path = os.path.join(selected_directory, video_stem + ".zip")

    img_vecs = _encode_frames(input_file)
    print("images encoded")
    if img_vecs.shape[0] == 0:
        raise ValueError("No frames could be decoded from the video.")

    rv_indices, _ = furthest_neighbours(
        x=img_vecs, downsample_size=downsample_size, seed=0)
    # Build the mask exactly as before (works whether rv_indices holds integer
    # positions or a boolean mask), then keep the chosen frame indices.
    mask = np.zeros((img_vecs.shape[0],), dtype=bool)
    mask[np.asarray(rv_indices)] = True
    selected = set(np.flatnonzero(mask).tolist())
    print("images selected")

    written = _write_selected_frames(input_file, selected, selected_directory)
    print("selected images extracted")

    _zip_files(written, zip_path)
    print("selected images zipped")
    return zip_path


demo = gr.Interface(
    enable_queue=True,
    title="Frame selection by visual difference",
    description="",
    fn=predict,
    inputs=[Video(label="Upload Video File"), Number(label="Downsample size")],
    outputs=File(label="Zip"),
)
demo.launch(enable_queue=True)