David Chuan-En Lin committed on
Commit 7415fc4 • 1 Parent(s): 81e09e6
README.md CHANGED
@@ -1,37 +1,77 @@
  ---
- title: Whichframe
- emoji: 🏠
- colorFrom: pink
- colorTo: gray
  sdk: streamlit
- app_file: app.py
- pinned: false
  ---

- # Configuration

- `title`: _string_
- Display title for the Space

- `emoji`: _string_
- Space emoji (emoji-only character allowed)

- `colorFrom`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

- `colorTo`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

- `sdk`: _string_
- Can be either `gradio` or `streamlit`

- `sdk_version` : _string_
- Only applicable for `streamlit` SDK.
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.

- `app_file`: _string_
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
- Path is relative to the root of the repository.

- `pinned`: _boolean_
- Whether the Space stays on top of your list.
  ---
+ title: Which Frame?
+ emoji: 🔍
+ colorFrom: purple
+ colorTo: purple
  sdk: streamlit
+ app_file: whichframe.py
  ---
+
+ # Which Frame?
+
+ Search a video **semantically** with AI. For example, try a natural-language query like "a person with sunglasses". You can also search with an image, as in Google's reverse image search, or with a combined text + image query. The underlying querying is powered by OpenAI's CLIP neural network for "zero-shot" image classification.
+
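As a rough illustration of the CLIP-based querying described above, here is a minimal sketch, not the app's actual code: it assumes the `clip` package installed from `requirements.txt` and a list of PIL frames, and the `rank_frames` name is illustrative (the real implementation is in `whichframe.py` further down).

```python
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def rank_frames(frames, query, top_k=5):
    """Return indices of the video frames that best match a text query."""
    images = torch.stack([preprocess(frame) for frame in frames]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(images)
        text_features = model.encode_text(clip.tokenize([query]).to(device))
    # Normalize so the dot product becomes a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarities = (image_features @ text_features.T).squeeze(1)
    return similarities.topk(min(top_k, len(frames))).indices
```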
+ ---
+
+ ## Try it out!
+
+ http://whichframe.chuanenlin.com
+
+ ---
+
+ ## Setting up
+
+ 1. Clone the repository.
+
+ ```bash
+ git clone https://github.com/chuanenlin/whichframe.git
+ cd whichframe
+ ```
+
+ 2. Install package dependencies.
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Run the app.
+
+ ```bash
+ streamlit run whichframe.py
+ ```
+
+ ---
+
+ ## Examples
+
+ ### 🔤 Text Search
+
+ #### Query
+
+ "three red cars side by side"
+
+ #### Result
+
+ ![three-red-cars-side-by-side](examples/three-red-cars-side-by-side.jpeg)
+
+ ### 🌅 Image Search
+
+ #### Query
+
+ ![helicopter-query](examples/helicopter-query.jpeg)
+
+ #### Result
+
+ ![helicopter-result](examples/helicopter-result.jpeg)
+
+ ### 🔤 Text + 🌅 Image Search
+
+ #### Query
+
+ "a red subaru" +
+
+ ![police-car-query](examples/police-car-query.jpeg)
+
+ #### Result
+
+ ![subaru-and-police-car-result](examples/subaru-and-police-car-result.jpeg)
SessionState.py ADDED
@@ -0,0 +1,94 @@
+ """Hack to add per-session state to Streamlit.
+
+ Works for Streamlit >= v0.65
+
+ Usage
+ -----
+
+ >>> import SessionState
+ >>>
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
+ >>> session_state.user_name
+ ''
+ >>> session_state.user_name = 'Mary'
+ >>> session_state.favorite_color
+ 'black'
+
+ Since you set user_name above, next time your script runs this will be the
+ result:
+ >>> session_state = get(user_name='', favorite_color='black')
+ >>> session_state.user_name
+ 'Mary'
+
+ """
+
+ import streamlit.report_thread as ReportThread
+ from streamlit.server.server import Server
+
+
+ class SessionState():
+     """SessionState: Add per-session state to Streamlit."""
+     def __init__(self, **kwargs):
+         """A new SessionState object.
+
+         Parameters
+         ----------
+         **kwargs : any
+             Default values for the session state.
+
+         Example
+         -------
+         >>> session_state = SessionState(user_name='', favorite_color='black')
+         >>> session_state.user_name = 'Mary'
+         >>> session_state.favorite_color
+         'black'
+
+         """
+         for key, val in kwargs.items():
+             setattr(self, key, val)
+
+
+ def get(**kwargs):
+     """Gets a SessionState object for the current session.
+
+     Creates a new object if necessary.
+
+     Parameters
+     ----------
+     **kwargs : any
+         Default values you want to add to the session state, if we're creating a
+         new one.
+
+     Example
+     -------
+     >>> session_state = get(user_name='', favorite_color='black')
+     >>> session_state.user_name
+     ''
+     >>> session_state.user_name = 'Mary'
+     >>> session_state.favorite_color
+     'black'
+
+     Since you set user_name above, next time your script runs this will be the
+     result:
+     >>> session_state = get(user_name='', favorite_color='black')
+     >>> session_state.user_name
+     'Mary'
+
+     """
+     # Hack to get the session object from Streamlit.
+     session_id = ReportThread.get_report_ctx().session_id
+     session_info = Server.get_current()._get_session_info(session_id)
+
+     if session_info is None:
+         raise RuntimeError('Could not get Streamlit session object.')
+
+     this_session = session_info.session
+
+     # Got the session object! Now let's attach some state into it.
+     if not hasattr(this_session, '_custom_session_state'):
+         this_session._custom_session_state = SessionState(**kwargs)
+
+     return this_session._custom_session_state
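Note that `streamlit.report_thread` and `streamlit.server.server` are internal modules that newer Streamlit releases no longer expose; on modern Streamlit the built-in `st.session_state` covers the same use case. A minimal sketch of the equivalent usage (assumes a recent Streamlit version; the keys mirror the docstring example above):

```python
import streamlit as st

# Built-in per-session state on newer Streamlit versions.
if "user_name" not in st.session_state:
    st.session_state.user_name = ""
if "favorite_color" not in st.session_state:
    st.session_state.favorite_color = "black"

st.session_state.user_name = "Mary"
st.write(st.session_state.favorite_color)  # -> "black"
```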
examples/helicopter-query.jpeg ADDED
Binary file
examples/helicopter-result.jpeg ADDED
Binary file
examples/police-car-query.jpeg ADDED
Binary file
examples/subaru-and-police-car-result.jpeg ADDED
Binary file
examples/subaru.jpeg ADDED
Binary file
examples/three-red-cars-side-by-side.jpeg ADDED
Binary file
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy
+ Pillow
+ streamlit
+ pytube
+ opencv-python-headless
+ -f https://download.pytorch.org/whl/torch_stable.html
+ torch==1.7.1+cpu
+ torchvision==0.8.2+cpu
+ git+https://github.com/openai/CLIP.git
+ humanfriendly
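If you install the pinned CPU-only wheels by hand instead of through `requirements.txt`, the equivalent pip invocation (same versions and find-links URL as above) would be roughly:

```bash
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install git+https://github.com/openai/CLIP.git
```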
whichframe.py ADDED
@@ -0,0 +1,168 @@
+ import streamlit as st
+ from pytube import YouTube
+ from pytube import extract
+ import cv2
+ from PIL import Image
+ import clip as openai_clip
+ import torch
+ import math
+ import numpy as np
+ import SessionState
+ import tempfile
+ from humanfriendly import format_timespan
+ import json
+ import sys
+ import io  # used by img_to_bytes below
+ from random import randrange
+ import requests
+
+ def fetch_video(url):
+     # Grab the 360p video-only MP4 stream for a YouTube URL.
+     yt = YouTube(url)
+     streams = yt.streams.filter(adaptive=True, subtype="mp4", resolution="360p", only_video=True)
+     length = yt.length
+     video = streams[0]
+     return video, video.url
+
+ @st.cache()
+ def extract_frames(video):
+     # Decode the video and keep one frame every N source frames (N is set below).
+     frames = []
+     capture = cv2.VideoCapture(video)
+     fps = capture.get(cv2.CAP_PROP_FPS)
+     current_frame = 0
+     while capture.isOpened():
+         ret, frame = capture.read()
+         if ret:
+             # OpenCV decodes to BGR; reverse the channels to get RGB for PIL.
+             frames.append(Image.fromarray(frame[:, :, ::-1]))
+         else:
+             break
+         current_frame += N
+         capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
+     return frames, fps
+
+ @st.cache()
+ def encode_frames(video_frames):
+     # Encode the sampled frames with CLIP in batches and L2-normalize the features.
+     batch_size = 256
+     batches = math.ceil(len(video_frames) / batch_size)
+     video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
+     for i in range(batches):
+         batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
+         batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
+         with torch.no_grad():
+             batch_features = model.encode_image(batch_preprocessed)
+             batch_features /= batch_features.norm(dim=-1, keepdim=True)
+         video_features = torch.cat((video_features, batch_features))
+     return video_features
+
+ def img_to_bytes(img):
+     img_byte_arr = io.BytesIO()
+     img.save(img_byte_arr, format='JPEG')
+     img_byte_arr = img_byte_arr.getvalue()
+     return img_byte_arr
+
+ def display_results(best_photo_idx):
+     st.markdown("**Top-5 matching results**")
+     result_arr = []
+     for frame_id in best_photo_idx:
+         result = ss.video_frames[frame_id]
+         st.image(result)
+         # Map the sampled-frame index back to a timestamp in the source video.
+         seconds = round(frame_id.cpu().numpy()[0] * N / ss.fps)
+         result_arr.append(seconds)
+         time = format_timespan(seconds)
+         if ss.input == "file":
+             st.write("Seen at " + str(time) + " into the video.")
+         else:
+             st.markdown("Seen at [" + str(time) + "](" + url + "&t=" + str(seconds) + "s) into the video.")
+     return result_arr
+
+ def text_search(search_query, display_results_count=5):
+     with torch.no_grad():
+         text_features = model.encode_text(openai_clip.tokenize(search_query).to(device))
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+     similarities = (100.0 * ss.video_features @ text_features.T)
+     values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+     result_arr = display_results(best_photo_idx)
+     return result_arr
+
+ def img_search(search_query, display_results_count=5):
+     with torch.no_grad():
+         image_features = model.encode_image(preprocess(Image.open(search_query)).unsqueeze(0).to(device))
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+     similarities = (100.0 * ss.video_features @ image_features.T)
+     values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+     display_results(best_photo_idx)
+
+ def text_and_img_search(text_search_query, image_search_query, display_results_count=5):
+     # Combine the normalized text and image embeddings into a single hybrid query.
+     with torch.no_grad():
+         image_features = model.encode_image(preprocess(Image.open(image_search_query)).unsqueeze(0).to(device))
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features = model.encode_text(openai_clip.tokenize(text_search_query).to(device))
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+         hybrid_features = image_features + text_features
+     similarities = (100.0 * ss.video_features @ hybrid_features.T)
+     values, best_photo_idx = similarities.topk(display_results_count, dim=0)
+     result_arr = display_results(best_photo_idx)
+     return result_arr
+
+ st.set_page_config(page_title="Which Frame?", page_icon="🔍", layout="centered", initial_sidebar_state="collapsed")
+
+ hide_streamlit_style = """
+ <style>
+ #MainMenu {visibility: hidden;}
+ footer {visibility: hidden;}
+ * {font-family: Avenir;}
+ .css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
+ a:link {text-decoration: none;}
+ a:hover {text-decoration: none;}
+ .st-ba {font-family: Avenir;}
+ </style>
+ """
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+ ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)
+
+ st.title("Which Frame?")
+ st.markdown("Search a video **semantically**. Which frame has a person with sunglasses and earphones? Try searching with **text**, **image**, or a combined **text + image**.")
+ url = st.text_input("Link to a YouTube video (Example: https://www.youtube.com/watch?v=sxaTnm_4YMY)")
+
+ # Keep one frame out of every N source frames when sampling the video.
+ N = 30
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = openai_clip.load("ViT-B/32", device=device)
+
+ if st.button("Process video (this may take a while)"):
+     ss.progress = 1
+     ss.video_start_time = 0
+     if url:
+         ss.input = "link"
+         ss.video, ss.video_name = fetch_video(url)
+         ss.id = extract.video_id(url)
+         ss.url = "https://www.youtube.com/watch?v=" + ss.id
+     else:
+         st.error("Please upload a video or link to a valid YouTube video")
+         st.stop()
+     ss.video_frames, ss.fps = extract_frames(ss.video_name)
+     ss.video_features = encode_frames(ss.video_frames)
+     st.video(ss.url)
+     ss.progress = 2
+
+ if ss.progress == 2:
+     ss.mode = st.selectbox("Select a search method (text, image, or text + image)", ("Text", "Image", "Text + Image"))
+     if ss.mode == "Text":
+         ss.text_query = st.text_input("Enter text query (Example: a person with sunglasses and earphones)")
+     elif ss.mode == "Image":
+         ss.img_query = st.file_uploader("Upload image query", type=["png", "jpg", "jpeg"])
+     else:
+         ss.text_query = st.text_input("Enter text query (Example: a person with sunglasses and earphones)")
+         ss.img_query = st.file_uploader("Upload image query", type=["png", "jpg", "jpeg"])
+
+     if st.button("Submit"):
+         if ss.mode == "Text":
+             if ss.text_query is not None:
+                 text_search(ss.text_query)
+         elif ss.mode == "Image":
+             if ss.img_query is not None:
+                 img_search(ss.img_query)
+         else:
+             if ss.text_query is not None and ss.img_query is not None:
+                 text_and_img_search(ss.text_query, ss.img_query)
+
+ st.markdown("By [David Chuan-En Lin](https://chuanenlin.com) at Carnegie Mellon University. The querying is powered by [OpenAI's CLIP neural network](https://openai.com/blog/clip) and the interface was built with [Streamlit](https://streamlit.io). Many aspects of this project are based on the kind work of [Vladimir Haltakov](https://haltakov.net) and [Haofan Wang](https://haofanwang.github.io).")