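"""Gradio demo: search inside a video with natural-language text queries using OpenAI's CLIP."""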
import cv2
from PIL import Image
import clip
import torch
import math
import datetime
import gradio as gr

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
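# Only every N-th frame of the video is sampled and encoded. The original file does not
# define N, so the value below is an assumption; tune it to trade speed for time resolution.
N = 30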
def search_video(search_query, video_frames, video_features, fps, display_results_count=1):
    # Encode and normalize the search query using CLIP
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Compute the similarity between the search query and each frame (cosine similarity,
    # since both the text and the frame features are normalized)
    similarities = (100.0 * video_features @ text_features.T)
    values, best_photo_idx = similarities.topk(display_results_count, dim=0)

    # Return the best matching frame together with its timestamp in the video
    frame_id = int(best_photo_idx[0])
    frame = video_frames[frame_id]
    seconds = round(frame_id * N / fps)
    return frame, f"Found at {str(datetime.timedelta(seconds=seconds))}"
def inference(video, text):
    # The frame images will be stored in video_frames
    video_frames = []

    # Open the video file
    capture = cv2.VideoCapture(video)
    fps = capture.get(cv2.CAP_PROP_FPS)

    current_frame = 0
    while capture.isOpened():
        # Read the current frame
        ret, frame = capture.read()

        # Convert it from BGR to RGB, wrap it in a PIL image (required for CLIP) and store it
        if ret:
            video_frames.append(Image.fromarray(frame[:, :, ::-1]))
        else:
            break

        # Skip N frames
        current_frame += N
        capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
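    # Release the video file once all sampled frames have been read
    capture.release()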
    # Print some statistics
    print(f"Frames extracted: {len(video_frames)}")

    # You can try tuning the batch size for very large videos, but it should usually be OK
    batch_size = 256
    batches = math.ceil(len(video_frames) / batch_size)

    # The encoded features will be stored in video_features
    video_features = torch.empty([0, 512], dtype=torch.float16).to(device)

    # Process each batch
    for i in range(batches):
        print(f"Processing batch {i+1}/{batches}")

        # Get the relevant frames
        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]

        # Preprocess the images for the batch
        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)

        # Encode with CLIP and normalize
        with torch.no_grad():
            batch_features = model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)

        # Append the batch to the tensor containing all features
        video_features = torch.cat((video_features, batch_features))

    # Print some stats
    print(f"Features: {video_features.shape}")

    return search_video(text, video_frames, video_features, fps)
title = "Video Search"
description = "demo for Anime2Sketch. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.05703'>Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis</a> | <a href='https://github.com/Mukosame/Anime2Sketch'>Github Repo</a></p>"
gr.Interface(
    inference,
    ["video", "text"],
    [gr.outputs.Image(type="pil", label="Output"), "text"],
    title=title,
    description=description,
    article=article,
    enable_queue=True
).launch(debug=True)