import math
import os
import tempfile

import cv2
import torch
import clip
from PIL import Image
from flask import Flask, request, jsonify

device = "cuda" if torch.cuda.is_available() else "cpu"
# load the OpenAI CLIP model (ViT-B/32 produces 512-dimensional embeddings)
model, preprocess = clip.load("ViT-B/32", device=device)

app = Flask(__name__)

# define the API endpoint (the route path here is an arbitrary choice)
@app.route('/search_videos', methods=['POST'])
def search_videos():
    # get the search query from the request parameters
    search_item = request.form.get('search_item')
    # get the uploaded video files from the request
    video_files = request.files.getlist('video_files')
    # store the names of the matching videos
    matching_videos = []
    # loop through all videos
    for video_file in video_files:
        # number of frames to skip between samples
        n = 120
        # store the sampled video frames
        video_frames = []
        # cv2.VideoCapture expects a file path, so write the upload to a temporary file first
        tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp.close()
        video_file.save(tmp.name)
        video_path = tmp.name
        # open the video
        capture = cv2.VideoCapture(video_path)
        fps = capture.get(cv2.CAP_PROP_FPS)  # unused here, but handy for time-based sampling
        current_frame = 0
        # read the first frame
        ret, frame = capture.read()
        while capture.isOpened() and ret:
            # convert BGR (OpenCV) to RGB (PIL) and store the frame
            video_frames.append(Image.fromarray(frame[:, :, ::-1]))
            # skip n frames ahead and read the next sample
            current_frame += n
            capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
            ret, frame = capture.read()
        capture.release()
        os.remove(video_path)
        # encode the frames with CLIP in batches
        batch_size = 256
        batches = math.ceil(len(video_frames) / batch_size)
        # collect the encoded features batch by batch; accumulating in a list and
        # concatenating once avoids hard-coding a dtype (CLIP runs in float16 on
        # GPU but float32 on CPU)
        feature_batches = []
        # process each batch
        for i in range(batches):
            # get the relevant frames
            batch_frames = video_frames[i * batch_size : (i + 1) * batch_size]
            # preprocess the frames for the batch
            batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
            # encode with CLIP and normalize
            with torch.no_grad():
                batch_features = model.encode_image(batch_preprocessed)
                batch_features /= batch_features.norm(dim=-1, keepdim=True)
            feature_batches.append(batch_features)
        video_features = torch.cat(feature_batches) if feature_batches else torch.empty(0, 512, device=device)
        # determine whether the video contains the search item
        if contain_search_item(video_features, search_item):
            matching_videos.append(video_file.filename)
    # return the list of matching video names as JSON
    return jsonify({'matching_videos': matching_videos})
def contain_search_item(video_features, search_query, threshold=30.0):
    # encode and normalize the search query using CLIP
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # compute the similarity between the search query and each sampled frame
    similarities = 100.0 * video_features @ text_features.T
    if similarities.numel() == 0:
        return False
    # the video matches if its best frame clears the similarity threshold
    # (the default cutoff of 30.0 is a tunable assumption)
    return similarities.max().item() >= threshold
if __name__ == '__main__':
    app.run(debug=True)
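
# A minimal client sketch for exercising the endpoint above. This is an
# illustrative assumption, not part of the original app: it presumes the
# server is running locally on Flask's default port and that the route is
# /search_videos as defined above. Run it as a separate script:
#
#   import requests
#
#   response = requests.post(
#       "http://127.0.0.1:5000/search_videos",
#       data={"search_item": "a dog catching a frisbee"},
#       files=[("video_files", ("clip1.mp4", open("clip1.mp4", "rb"), "video/mp4"))],
#   )
#   print(response.json())  # e.g. {"matching_videos": ["clip1.mp4"]}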