Spaces:
Runtime error
Runtime error
File size: 4,670 Bytes
a9cbf7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import io
import math

import clip as openai_clip
import cv2
import streamlit as st
import torch
from humanfriendly import format_timespan
from PIL import Image
from pytube import YouTube
from pytube import extract

import SessionState
def fetch_video(url):
    """Resolve a YouTube URL to a downloadable 360p mp4 video-only stream.

    Args:
        url: Full YouTube video URL.

    Returns:
        Tuple of (pytube Stream, direct stream URL string).

    Halts the Streamlit script (via st.stop) when the video is 5 minutes
    or longer, or when no matching stream is available.
    """
    yt = YouTube(url)
    streams = yt.streams.filter(adaptive=True, subtype="mp4", resolution="360p", only_video=True)
    length = yt.length
    if length >= 300:
        st.error("Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.")
        st.stop()
    # streams[0] raises IndexError when the filter matched nothing;
    # .first() returns None so we can show a friendly error instead.
    video = streams.first()
    if video is None:
        st.error("No suitable 360p mp4 stream was found for this video.")
        st.stop()
    return video, video.url
@st.cache()
def extract_frames(video):
    """Sample every N-th frame from a video source.

    Args:
        video: Path or URL accepted by cv2.VideoCapture.

    Returns:
        Tuple of (list of PIL RGB images, frames-per-second of the source).
    """
    frames = []
    capture = cv2.VideoCapture(video)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        current_frame = 0
        while capture.isOpened():
            ret, frame = capture.read()
            if not ret:  # end of stream or decode failure
                break
            # OpenCV decodes BGR; reverse the channel axis to get RGB for PIL.
            frames.append(Image.fromarray(frame[:, :, ::-1]))
            # Jump ahead by the module-level sampling stride N.
            current_frame += N
            capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
    finally:
        # Release the decoder handle — the original leaked it.
        capture.release()
    return frames, fps
@st.cache()
def encode_frames(video_frames):
    """Encode PIL frames into unit-norm CLIP image embeddings.

    Args:
        video_frames: List of PIL images.

    Returns:
        torch.Tensor of shape (len(video_frames), 512) on the module-level
        `device`, one L2-normalized CLIP feature row per frame.
    """
    batch_size = 256
    batches = math.ceil(len(video_frames) / batch_size)
    # Collect per-batch features and concatenate once at the end: the
    # original re-allocated a growing tensor with torch.cat every loop
    # (quadratic) and seeded it with float16, which makes torch.cat fail
    # on modern torch when the model emits float32 (CPU path).
    feature_chunks = []
    for i in range(batches):
        batch_frames = video_frames[i * batch_size : (i + 1) * batch_size]
        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        feature_chunks.append(batch_features)
    if not feature_chunks:
        # Preserve the original empty-input result: a (0, 512) tensor.
        return torch.empty([0, 512], dtype=torch.float16).to(device)
    return torch.cat(feature_chunks)
def img_to_bytes(img):
    """Serialize a PIL image to raw JPEG bytes.

    Args:
        img: PIL Image (or any object exposing a compatible .save method).

    Returns:
        bytes containing the JPEG-encoded image.
    """
    # Requires the module-level `import io` — it was missing from the
    # original file, which made this function raise NameError at runtime.
    buffer = io.BytesIO()
    img.save(buffer, format='JPEG')
    return buffer.getvalue()
def display_results(best_photo_idx):
    """Render the top-matching frames and their timestamps in the UI.

    Args:
        best_photo_idx: Indices of the best-matching frames from
            torch.topk over the similarity matrix — each element looks
            like a 1-element tensor (indexed with [0] below); TODO
            confirm the exact shape against text_search.

    Returns:
        List of timestamps (whole seconds) for each displayed frame.
    """
    st.markdown("**Top-5 matching results**")
    result_arr = []
    for frame_id in best_photo_idx:
        result = ss.video_frames[frame_id]
        st.image(result)
        # frame index * sampling stride N / fps -> seconds into the video
        seconds = round(frame_id.cpu().numpy()[0] * N / ss.fps)
        result_arr.append(seconds)
        # Human-readable form, e.g. "1 minute and 5 seconds".
        time = format_timespan(seconds)
        if ss.input == "file":
            # Uploaded file: no URL to deep-link into, plain text only.
            st.write("Seen at " + str(time) + " into the video.")
        else:
            # YouTube link: deep-link to the exact second via &t=.
            # NOTE(review): uses the module-level `url` text-input value,
            # not ss.url — confirm the two stay in sync across reruns.
            st.markdown("Seen at [" + str(time) + "](" + url + "&t=" + str(seconds) + "s) into the video.")
    return result_arr
def text_search(search_query, display_results_count=5):
    """Find, display, and return timestamps of the frames best matching a text query.

    Args:
        search_query: Free-text description of the frame to find.
        display_results_count: How many top matches to show.

    Returns:
        List of timestamps (seconds) as produced by display_results.
    """
    tokens = openai_clip.tokenize(search_query).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Both feature sets are unit-norm, so this is cosine similarity x100.
    similarities = 100.0 * ss.video_features @ text_features.T
    _, best_photo_idx = similarities.topk(display_results_count, dim=0)
    return display_results(best_photo_idx)
# ---- Page configuration and global setup (runs on every Streamlit rerun) ----
st.set_page_config(page_title="Which Frame?", page_icon = "🔍", layout = "centered", initial_sidebar_state = "collapsed")
# Hide Streamlit's default menu/footer and set the app font via injected CSS.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
* {font-family: Avenir;}
.css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
a:link {text-decoration: none;}
a:hover {text-decoration: none;}
.st-ba {font-family: Avenir;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# Per-session state that survives Streamlit reruns (SessionState helper module).
ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)
st.title("Which Frame?")
st.markdown("Search a video **semantically**. For example: Which frame has a person with sunglasses and earphones?")
url = st.text_input("Link to a YouTube video (Example: https://www.youtube.com/watch?v=sxaTnm_4YMY)")
# Sampling stride: extract_frames keeps every N-th frame.
N = 30
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP ViT-B/32 model and its image-preprocessing transform.
model, preprocess = openai_clip.load("ViT-B/32", device=device)
# ---- Step 1: fetch the video, extract frames, and encode them with CLIP ----
if st.button("Process video (this may take a while)"):
    ss.progress = 1
    ss.video_start_time = 0
    if url:
        ss.input = "link"
        ss.video, ss.video_name = fetch_video(url)
        # Canonicalize the URL from the video id so deep links are well-formed.
        ss.id = extract.video_id(url)
        ss.url = "https://www.youtube.com/watch?v=" + ss.id
    else:
        st.error("Please upload a video or link to a valid YouTube video")
        st.stop()
    ss.video_frames, ss.fps = extract_frames(ss.video_name)
    ss.video_features = encode_frames(ss.video_frames)
    st.video(ss.url)
    ss.progress = 2

# ---- Step 2: once processing is done, accept search queries ----
if ss.progress == 2:
    ss.text_query = st.text_input("Enter search query (Example: a person with sunglasses and earphones)")
    if st.button("Submit"):
        # st.text_input returns "" (never None) when empty, so the original
        # `is not None` guard always passed and empty queries were searched.
        # Only run the search for a non-empty query.
        if ss.text_query:
            text_search(ss.text_query)