whichframe / whichframe.py
David Chuan-En Lin
limit video length 453ac06
import streamlit as st
from pytube import YouTube
from pytube import extract
import cv2
from PIL import Image
import clip as openai_clip
import torch
import math
import numpy as np
import SessionState
import tempfile
from humanfriendly import format_timespan
import io  # needed by img_to_bytes (io.BytesIO)
import json
import sys
from random import randrange
import requests

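# Fetch a 360p video-only MP4 stream for the given YouTube URL, refusing
# videos of 5 minutes or longer to keep processing time manageable.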
def fetch_video(url):
  yt = YouTube(url)
  streams = yt.streams.filter(adaptive=True, subtype="mp4", resolution="360p", only_video=True)
  length = yt.length
  if length >= 300:
    st.error("Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.")
    st.stop()
  video = streams[0]
  return video, video.url

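# Decode the video with OpenCV, keeping one frame every N source frames
# (roughly one frame per second for a 30 fps video), and return the sampled
# frames as PIL images along with the video's frame rate.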
@st.cache()
def extract_frames(video):
  frames = []
  capture = cv2.VideoCapture(video)
  fps = capture.get(cv2.CAP_PROP_FPS)
  current_frame = 0
  while capture.isOpened():
    ret, frame = capture.read()
    if ret == True:
      frames.append(Image.fromarray(frame[:, :, ::-1]))  # convert BGR to RGB
    else:
      break
    current_frame += N
    capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
  capture.release()
  return frames, fps

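# Encode the sampled frames with CLIP's image encoder in batches and return
# an L2-normalized feature matrix (one 512-dimensional row per frame).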
@st.cache()
def encode_frames(video_frames):
  batch_size = 256
  batches = math.ceil(len(video_frames) / batch_size)
  video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
  for i in range(batches):
    batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
    batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
    with torch.no_grad():
      batch_features = model.encode_image(batch_preprocessed)
      batch_features /= batch_features.norm(dim=-1, keepdim=True)
    video_features = torch.cat((video_features, batch_features))
  return video_features

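# Helper that serializes a PIL image to JPEG bytes (not called elsewhere in
# this file).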
def img_to_bytes(img):
  img_byte_arr = io.BytesIO()
  img.save(img_byte_arr, format='JPEG')
  img_byte_arr = img_byte_arr.getvalue()
  return img_byte_arr

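# Display the top matching frames, report when each one appears in the video,
# and, for YouTube links, link straight to that timestamp.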
def display_results(best_photo_idx):
  st.markdown("**Top-5 matching results**")
  result_arr = []
  for frame_id in best_photo_idx:
    result = ss.video_frames[frame_id]
    st.image(result)
    seconds = round(frame_id.cpu().numpy()[0] * N / ss.fps)
    result_arr.append(seconds)
    time = format_timespan(seconds)
    if ss.input == "file":
      st.write("Seen at " + str(time) + " into the video.")
    else:
      st.markdown("Seen at [" + str(time) + "](" + ss.url + "&t=" + str(seconds) + "s) into the video.")
  return result_arr

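# Embed a natural-language query with CLIP's text encoder, score every frame
# by similarity, and display the top matches.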
def text_search(search_query, display_results_count=5):
  with torch.no_grad():
    text_features = model.encode_text(openai_clip.tokenize(search_query).to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)
  similarities = (100.0 * ss.video_features @ text_features.T)
  values, best_photo_idx = similarities.topk(display_results_count, dim=0)
  result_arr = display_results(best_photo_idx)
  return result_arr

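# Same as text_search, but use an uploaded example image as the query.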
def img_search(search_query, display_results_count=5):
  with torch.no_grad():
    image_features = model.encode_image(preprocess(Image.open(search_query)).unsqueeze(0).to(device))
    image_features /= image_features.norm(dim=-1, keepdim=True)
  similarities = (100.0 * ss.video_features @ image_features.T)
  values, best_photo_idx = similarities.topk(display_results_count, dim=0)
  display_results(best_photo_idx)

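# Combine a text query and an image query by summing their normalized CLIP
# embeddings, then rank frames against the combined feature vector.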
def text_and_img_search(text_search_query, image_search_query, display_results_count=5):
  with torch.no_grad():
    image_features = model.encode_image(preprocess(Image.open(image_search_query)).unsqueeze(0).to(device))
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features = model.encode_text(openai_clip.tokenize(text_search_query).to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)
    hybrid_features = image_features + text_features
  similarities = (100.0 * ss.video_features @ hybrid_features.T)
  values, best_photo_idx = similarities.topk(display_results_count, dim=0)
  result_arr = display_results(best_photo_idx)
  return result_arr

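# Streamlit page configuration and custom styling.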
st.set_page_config(page_title="Which Frame?", page_icon="🔍", layout="centered", initial_sidebar_state="collapsed")

hide_streamlit_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            * {font-family: Avenir;}
            .css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
            a:link {text-decoration: none;}
            a:hover {text-decoration: none;}
            .st-ba {font-family: Avenir;}
            </style>
            """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

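# Persistent per-user state so frames and features survive Streamlit reruns.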
ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)

st.title("Which Frame?")
st.markdown("Search a video **semantically**. Which frame has a person with sunglasses and earphones? Try searching with **text**, **image**, or a combined **text + image**.")
url = st.text_input("Link to a YouTube video (Example: https://www.youtube.com/watch?v=sxaTnm_4YMY)")

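# Keep every Nth frame when sampling, and load the CLIP ViT-B/32 model on the
# GPU if one is available, otherwise on the CPU.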
N = 30

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = openai_clip.load("ViT-B/32", device=device)

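# Step 1: fetch the video stream, sample and encode its frames with CLIP, then
# embed the YouTube player so the user can search the video.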
if st.button("Process video (this may take a while)"):
  ss.progress = 1
  ss.video_start_time = 0
  if url:
    ss.input = "link"
    ss.video, ss.video_name = fetch_video(url)
    ss.id = extract.video_id(url)
    ss.url = "https://www.youtube.com/watch?v=" + ss.id
  else:
    st.error("Please upload a video or link to a valid YouTube video")
    st.stop()
  ss.video_frames, ss.fps = extract_frames(ss.video_name)
  ss.video_features = encode_frames(ss.video_frames)
  st.video(ss.url)
  ss.progress = 2

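# Step 2: once the video has been processed, let the user pick a search method
# and run the corresponding text, image, or combined text + image query.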
if ss.progress == 2:
  ss.mode = st.selectbox("Select a search method (text, image, or text + image)", ("Text", "Image", "Text + Image"))
  if ss.mode == "Text":
    ss.text_query = st.text_input("Enter text query (Example: a person with sunglasses and earphones)")
  elif ss.mode == "Image":
    ss.img_query = st.file_uploader("Upload image query", type=["png", "jpg", "jpeg"])
  else:
    ss.text_query = st.text_input("Enter text query (Example: a person with sunglasses and earphones)")
    ss.img_query = st.file_uploader("Upload image query", type=["png", "jpg", "jpeg"])

  if st.button("Submit"):
    if ss.mode == "Text":
      if ss.text_query is not None:
        text_search(ss.text_query)
    elif ss.mode == "Image":
      if ss.img_query is not None:
        img_search(ss.img_query)
    else:
      if ss.text_query is not None and ss.img_query is not None:
        text_and_img_search(ss.text_query, ss.img_query)

st.markdown("By [David Chuan-En Lin](https://chuanenlin.com) at Carnegie Mellon University. The querying is powered by [OpenAI's CLIP neural network](https://openai.com/blog/clip) and the interface was built with [Streamlit](https://streamlit.io). Many aspects of this project are based on the kind work of [Vladimir Haltakov](https://haltakov.net) and [Haofan Wang](https://haofanwang.github.io).")