David Chuan-En Lin committed
Commit a9cbf7c • 1 Parent(s): 68d2ea9
Files changed (4)
  1. README.md +8 -6
  2. SessionState.py +70 -0
  3. requirements.txt +6 -0
  4. whichframe.py +130 -0
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
-title: Which Frame
-emoji: 🦀
+title: Which Frame?
+emoji: 🔍
 colorFrom: pink
-colorTo: blue
+colorTo: purple
 sdk: streamlit
-sdk_version: 1.10.0
-app_file: app.py
+sdk_version: 1.1.0
+app_file: whichframe.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Which Frame?
+
+**Semantic** video search. For example, try a natural language search query like "a person with sunglasses".
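
For orientation (an illustrative sketch, not part of the commit): the search embeds the query text and every sampled frame into CLIP's shared vector space and ranks frames by cosine similarity. whichframe.py below does exactly this, with batching and a Streamlit UI; the core idea, assuming `frames` is a list of PIL images already sampled from a video, is roughly:

```python
# Illustrative sketch of CLIP-based semantic frame search.
# Assumes `frames` is a list of PIL.Image objects sampled from a video.
import clip  # the package installed from github.com/openai/CLIP
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def rank_frames(frames, query, top_k=5):
    images = torch.stack([preprocess(f) for f in frames]).to(device)
    tokens = clip.tokenize([query]).to(device)
    with torch.no_grad():
        image_feats = model.encode_image(images)
        text_feats = model.encode_text(tokens)
    # Normalize so the dot product below is cosine similarity.
    image_feats /= image_feats.norm(dim=-1, keepdim=True)
    text_feats /= text_feats.norm(dim=-1, keepdim=True)
    scores = (image_feats @ text_feats.T).squeeze(1)
    return scores.topk(min(top_k, len(frames)))  # (values, frame indices)
```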
SessionState.py ADDED
@@ -0,0 +1,70 @@
+import streamlit.report_thread as ReportThread
+from streamlit.server.server import Server
+
+
+class SessionState():
+    """SessionState: Add per-session state to Streamlit."""
+    def __init__(self, **kwargs):
+        """A new SessionState object.
+
+        Parameters
+        ----------
+        **kwargs : any
+            Default values for the session state.
+
+        Example
+        -------
+        >>> session_state = SessionState(user_name='', favorite_color='black')
+        >>> session_state.user_name
+        ''
+        >>> session_state.favorite_color
+        'black'
+
+        """
+        for key, val in kwargs.items():
+            setattr(self, key, val)
+
+
+def get(**kwargs):
+    """Gets a SessionState object for the current session.
+
+    Creates a new object if necessary.
+
+    Parameters
+    ----------
+    **kwargs : any
+        Default values you want to add to the session state, if we're creating a
+        new one.
+
+    Example
+    -------
+    >>> session_state = get(user_name='', favorite_color='black')
+    >>> session_state.user_name
+    ''
+    >>> session_state.user_name = 'Mary'
+    >>> session_state.favorite_color
+    'black'
+
+    Since you set user_name above, next time your script runs this will be the
+    result:
+    >>> session_state = get(user_name='', favorite_color='black')
+    >>> session_state.user_name
+    'Mary'
+
+    """
+    # Hack to get the session object from Streamlit.
+
+    session_id = ReportThread.get_report_ctx().session_id
+    session_info = Server.get_current()._get_session_info(session_id)
+
+    if session_info is None:
+        raise RuntimeError('Could not get Streamlit session object.')
+
+    this_session = session_info.session
+
+    # Got the session object! Now let's attach some state into it.
+
+    if not hasattr(this_session, '_custom_session_state'):
+        this_session._custom_session_state = SessionState(**kwargs)
+
+    return this_session._custom_session_state
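
A hypothetical usage sketch (not part of the commit), to show what this module buys you: Streamlit re-runs the entire script on every widget interaction, and `SessionState.get` attaches a per-session object that survives those re-runs. It relies on pre-1.4 Streamlit internals (`streamlit.report_thread` was removed in later releases), which matches the `sdk_version: 1.1.0` pinned in README.md:

```python
import streamlit as st
import SessionState

# The defaults apply only the first time this browser session calls get().
ss = SessionState.get(clicks=0)

if st.button("Click me"):
    ss.clicks += 1  # persists across the re-run triggered by the click

st.write("Clicked", ss.clicks, "times")
```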
requirements.txt ADDED
@@ -0,0 +1,6 @@
+Pillow
+pytube
+opencv-python-headless
+torch
+git+https://github.com/openai/CLIP.git
+humanfriendly
whichframe.py ADDED
@@ -0,0 +1,130 @@
+import streamlit as st
+from pytube import YouTube
+from pytube import extract
+import cv2
+from PIL import Image
+import clip as openai_clip
+import torch
+import math
+import io
+import SessionState
+from humanfriendly import format_timespan
+
+def fetch_video(url):
+    yt = YouTube(url)
+    streams = yt.streams.filter(adaptive=True, subtype="mp4", resolution="360p", only_video=True)
+    length = yt.length
+    if length >= 300:
+        st.error("Please find a YouTube video shorter than 5 minutes. Sorry about this; server capacity is limited for the time being.")
+        st.stop()
+    video = streams[0]
+    return video, video.url
+
+@st.cache()
+def extract_frames(video):
+    frames = []
+    capture = cv2.VideoCapture(video)
+    fps = capture.get(cv2.CAP_PROP_FPS)
+    current_frame = 0
+    while capture.isOpened():
+        ret, frame = capture.read()
+        if ret:
+            frames.append(Image.fromarray(frame[:, :, ::-1]))
+        else:
+            break
+        current_frame += N
+        capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
+    return frames, fps
+
+@st.cache()
+def encode_frames(video_frames):
+    batch_size = 256
+    batches = math.ceil(len(video_frames) / batch_size)
+    video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
+    for i in range(batches):
+        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
+        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
+        with torch.no_grad():
+            batch_features = model.encode_image(batch_preprocessed)
+            batch_features /= batch_features.norm(dim=-1, keepdim=True)
+        video_features = torch.cat((video_features, batch_features))
+    return video_features
+
+def img_to_bytes(img):
+    img_byte_arr = io.BytesIO()
+    img.save(img_byte_arr, format='JPEG')
+    img_byte_arr = img_byte_arr.getvalue()
+    return img_byte_arr
+
+def display_results(best_photo_idx):
+    st.markdown("**Top-5 matching results**")
+    result_arr = []
+    for frame_id in best_photo_idx:
+        result = ss.video_frames[frame_id]
+        st.image(result)
+        seconds = round(frame_id.cpu().numpy()[0] * N / ss.fps)
+        result_arr.append(seconds)
+        time = format_timespan(seconds)
+        if ss.input == "file":
+            st.write("Seen at " + str(time) + " into the video.")
+        else:
+            st.markdown("Seen at [" + str(time) + "](" + url + "&t=" + str(seconds) + "s) into the video.")
+    return result_arr
+
+def text_search(search_query, display_results_count=5):
+    with torch.no_grad():
+        text_features = model.encode_text(openai_clip.tokenize(search_query).to(device))
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+    similarities = (100.0 * ss.video_features @ text_features.T)
+    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+    result_arr = display_results(best_photo_idx)
+    return result_arr
+
+st.set_page_config(page_title="Which Frame?", page_icon="🔍", layout="centered", initial_sidebar_state="collapsed")
+
+hide_streamlit_style = """
+<style>
+#MainMenu {visibility: hidden;}
+footer {visibility: hidden;}
+* {font-family: Avenir;}
+.css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
+a:link {text-decoration: none;}
+a:hover {text-decoration: none;}
+.st-ba {font-family: Avenir;}
+</style>
+"""
+st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)
+
+st.title("Which Frame?")
+st.markdown("Search a video **semantically**. For example: Which frame has a person with sunglasses and earphones?")
+url = st.text_input("Link to a YouTube video (Example: https://www.youtube.com/watch?v=sxaTnm_4YMY)")
+
+N = 30
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = openai_clip.load("ViT-B/32", device=device)
+
+if st.button("Process video (this may take a while)"):
+    ss.progress = 1
+    ss.video_start_time = 0
+    if url:
+        ss.input = "link"
+        ss.video, ss.video_name = fetch_video(url)
+        ss.id = extract.video_id(url)
+        ss.url = "https://www.youtube.com/watch?v=" + ss.id
+    else:
+        st.error("Please upload a video or link to a valid YouTube video")
+        st.stop()
+    ss.video_frames, ss.fps = extract_frames(ss.video_name)
+    ss.video_features = encode_frames(ss.video_frames)
+    st.video(ss.url)
+    ss.progress = 2
+
+if ss.progress == 2:
+    ss.text_query = st.text_input("Enter search query (Example: a person with sunglasses and earphones)")
+
+    if st.button("Submit"):
+        if ss.text_query is not None:
+            text_search(ss.text_query)
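
A worked example (illustrative values, not part of the commit) of the timestamp arithmetic in `display_results`: because `extract_frames` keeps one frame per `N = 30` source frames, a match at sampled-frame index `i` sits roughly `i * N / fps` seconds into the video, and that same value builds the YouTube deep link:

```python
# Timestamp mapping used in display_results(), with assumed example values.
N, fps = 30, 30.0        # sampling stride and a 30 fps source video
frame_id = 42            # index into the sampled frame list
seconds = round(frame_id * N / fps)  # -> 42, i.e. 42 s into the video
url = "https://www.youtube.com/watch?v=sxaTnm_4YMY"
link = url + "&t=" + str(seconds) + "s"  # jumps playback to the match
```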