Spaces:
Runtime error
Runtime error
File size: 3,277 Bytes
c9ec901 bd30780 c9ec901 9e60d26 43d1c17 c9ec901 bd30780 c9ec901 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import cv2
from PIL import Image
import clip
import torch
import math
import numpy as np
import torch
import datetime
import gradio as gr
# Load OpenAI's CLIP model (ViT-B/32 image backbone) together with the image
# preprocessing transform that matches it. Use the GPU when CUDA is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Sample one frame every FRAME_STEP source frames. A sparse sampling is enough
# for CLIP to locate a scene, and encoding every frame would be very slow.
FRAME_STEP = 120


def search_video(search_query, display_heatmap=True, display_results_count=1,
                 video_frames=None, video_features=None, fps=None,
                 frame_step=FRAME_STEP):
    """Return the sampled frame that best matches a natural-language query.

    Args:
        search_query: Text describing the scene to look for.
        display_heatmap: Unused; kept so existing callers keep working.
        display_results_count: Unused beyond backward compatibility. Only the
            single best match is returned.
        video_frames: PIL images sampled from the video, in playback order.
        video_features: L2-normalized CLIP image features, one row per entry
            of ``video_frames``.
        fps: Frames per second of the source video, used for the timestamp.
        frame_step: Number of source frames between consecutive samples.

    Returns:
        ``(frame, message)``: the best-matching PIL image and a string
        ``"Found at H:MM:SS"`` with its approximate position in the video.

    Raises:
        ValueError: If frames, features or a usable fps are missing.
    """
    # Previously these were read from undefined globals, which always failed.
    if not video_frames or video_features is None or not fps:
        raise ValueError(
            "search_video needs the sampled frames, their CLIP features and the video fps"
        )

    # Encode and normalize the search query using CLIP.
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Both sides are L2-normalized, so this is cosine similarity (scaled by 100,
    # as in CLIP's logits). Shape: (num_frames, 1).
    similarities = 100.0 * video_features @ text_features.T
    best = int(similarities[:, 0].argmax())

    # Each sample is frame_step source frames apart; convert to seconds.
    seconds = round(best * frame_step / fps)
    return video_frames[best], f"Found at {datetime.timedelta(seconds=seconds)}"


def _sample_frames(video, frame_step):
    """Read every ``frame_step``-th frame of ``video``; return ``(frames, fps)``.

    Frames are returned as RGB PIL images. The capture is always released.
    """
    capture = cv2.VideoCapture(video)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frames = []
        current_frame = 0
        while capture.isOpened():
            ret, frame = capture.read()
            if not ret:  # end of stream or unreadable video
                break
            # OpenCV decodes to BGR; reverse the channel axis to get RGB for PIL/CLIP.
            frames.append(Image.fromarray(frame[:, :, ::-1]))
            # Skip ahead to the next sample instead of decoding every frame.
            current_frame += frame_step
            capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
    finally:
        capture.release()
    return frames, fps


def _encode_frames(frames, batch_size):
    """Encode PIL ``frames`` with CLIP in batches; return L2-normalized features."""
    batches = math.ceil(len(frames) / batch_size)
    features = []
    for i in range(batches):
        print(f"Processing batch {i+1}/{batches}")
        batch_frames = frames[i * batch_size:(i + 1) * batch_size]
        batch_preprocessed = torch.stack([preprocess(f) for f in batch_frames]).to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        features.append(batch_features)
    # Concatenate once at the end. This keeps whatever dtype and width the model
    # produces, instead of assuming fp16 and 512 columns.
    return torch.cat(features)


def inference(video, text, frame_step=FRAME_STEP, batch_size=256):
    """Find the moment in ``video`` that best matches the text query ``text``.

    Args:
        video: Path to the video file (as passed by the Gradio "video" input).
        text: Natural-language description of the scene to find.
        frame_step: Number of source frames between sampled frames (>= 1).
        batch_size: Frames encoded per CLIP forward pass. Lower it if you run
            out of memory on very large videos.

    Returns:
        ``(frame, message)`` as produced by ``search_video``.

    Raises:
        ValueError: If ``frame_step`` is not positive or no frame could be read.
    """
    if frame_step < 1:
        # A step of 0 would seek back to the same frame forever.
        raise ValueError("frame_step must be a positive integer")

    video_frames, fps = _sample_frames(video, frame_step)
    print(f"Frames extracted: {len(video_frames)}")
    if not video_frames:
        raise ValueError(f"Could not read any frames from {video!r}")

    video_features = _encode_frames(video_frames, batch_size)
    print(f"Features: {video_features.shape}")

    return search_video(
        text,
        video_frames=video_frames,
        video_features=video_features,
        fps=fps,
        frame_step=frame_step,
    )
# Page text for the Gradio app. The previous description and article were
# copied from the Anime2Sketch demo and did not describe this app.
title = "Video Search"
description = (
    "Search inside a video with natural language using OpenAI's CLIP. "
    "Upload a video and describe the scene you are looking for; the app "
    "returns the best-matching frame and its approximate timestamp."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2103.00020'>Learning Transferable Visual "
    "Models From Natural Language Supervision (CLIP)</a> | "
    "<a href='https://github.com/openai/CLIP'>CLIP Github Repo</a></p>"
)
# The inputs map positionally onto inference(video, text), and the outputs
# onto its (frame, message) return value.
# NOTE(review): gr.outputs.* and enable_queue are Gradio 2.x-era APIs that later
# Gradio releases removed. Pin the gradio version or migrate to gr.Image/queue().
gr.Interface(
    inference,
    ["video", "text"],
    [gr.outputs.Image(type="pil", label="Output"), "text"],
    title=title,
    description=description,
    article=article,
    enable_queue=True,
).launch(debug=True)
|