import math
import os
import tempfile

import cv2
import clip
import torch
from PIL import Image
from flask import Flask, request, jsonify

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the OpenAI CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)

app = Flask(__name__)
# define the API endpoint
@app.route('/search_videos', methods=['POST'])
def search_videos():
    # get the search query from the request parameters
    search_item = request.form.get('search_item')

    # get the uploaded video files from the request
    video_files = request.files.getlist('video_files')

    # store the names of the matching videos
    matching_videos = []

    # loop through all videos
    for video_file in video_files:
        # number of frames to skip between sampled frames
        n = 120

        # store the sampled video frames
        video_frames = []

        # OpenCV cannot read from a Flask FileStorage object directly,
        # so write the upload to a temporary file and open it by path
        suffix = os.path.splitext(video_file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            video_file.save(tmp)
            video_path = tmp.name

        # open the video
        capture = cv2.VideoCapture(video_path)
        fps = capture.get(cv2.CAP_PROP_FPS)
        current_frame = 0

        # read every n-th frame until the video is exhausted
        while capture.isOpened():
            ret, frame = capture.read()
            if not ret:
                break
            # convert BGR (OpenCV) to RGB (PIL) and store the frame
            video_frames.append(Image.fromarray(frame[:, :, ::-1]))
            # skip n frames
            current_frame += n
            capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)

        capture.release()
        os.remove(video_path)

        # encode the frames in batches
        batch_size = 256
        batches = math.ceil(len(video_frames) / batch_size)

        # store the encoded features (ViT-B/32 produces 512-dimensional embeddings);
        # use the model's dtype so both CPU (fp32) and GPU (fp16) work
        video_features = torch.empty([0, 512], dtype=model.dtype, device=device)

        # process each batch
        for i in range(batches):
            # get the frames belonging to this batch
            batch_frames = video_frames[i * batch_size : (i + 1) * batch_size]

            # preprocess the frames for the batch
            batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)

            # encode with CLIP and normalize
            with torch.no_grad():
                batch_features = model.encode_image(batch_preprocessed)
                batch_features /= batch_features.norm(dim=-1, keepdim=True)

            # append the batch to the tensor containing all features
            video_features = torch.cat((video_features, batch_features))

        # determine whether the video contains the search item
        if contain_search_item(video_features, search_item):
            matching_videos.append(video_file.filename)

    # return the list of matching video filenames as JSON
    return jsonify(matching_videos)
def contain_search_item(video_features, search_query, threshold=25.0):
    # `threshold` is the minimum score (on the 100-scaled CLIP similarity)
    # for a video to count as a match; it is a tunable heuristic

    # encode and normalize the search query using CLIP
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # no frames were sampled from the video
    if video_features.shape[0] == 0:
        return False

    # compute the similarity between the search query and each frame
    similarities = 100.0 * video_features @ text_features.T

    # check whether the best-matching frame clears the threshold
    best_value, best_frame_idx = similarities.topk(1, dim=0)
    return best_value.item() >= threshold
if __name__ == '__main__':
    app.run(debug=True)
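
# Example client call (a minimal sketch, not part of the app itself): it assumes
# the server is running locally on Flask's default port 5000, and the query text
# and video paths below are placeholders.
#
# import requests
#
# response = requests.post(
#     "http://127.0.0.1:5000/search_videos",
#     data={"search_item": "a dog playing in the snow"},
#     files=[
#         ("video_files", open("clip1.mp4", "rb")),
#         ("video_files", open("clip2.mp4", "rb")),
#     ],
# )
# print(response.json())  # filenames of the videos that matched the query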