clifs / app.py
ncoop57
Update app layout and change to streaming only max 2mins of video instead of downloading whole thing using ffmpeg
c40e192
raw
history blame
2.98 kB
import ffmpeg
import torch
import youtube_dl
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer, util, models
from clip import CLIPModel
from PIL import Image
@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Build the CLIP-backed sentence-transformer used for encoding.

    The function is wrapped in ``st.cache`` (a single entry, output mutation
    allowed) so Streamlit can reuse the constructed model between calls.

    Returns:
        A ``SentenceTransformer`` whose only module is a ``CLIPModel``,
        moved to the CPU with dtype ``torch.float32``.
    """
    cpu = torch.device('cpu')
    encoder = SentenceTransformer(modules=[CLIPModel()])
    return encoder.to(dtype=torch.float32, device=cpu)
def get_embedding(model, query, video):
    """Encode a text query and a sequence of video frames with ``model``.

    Args:
        model: Encoder exposing ``encode(inputs, device=...)``.
        query: Text handed to ``model.encode`` unchanged.
        video: Iterable of frame arrays; each one is converted with
            ``Image.fromarray`` before being encoded.

    Returns:
        A ``(text_emb, img_embs)`` tuple holding the encoding of ``query``
        and the encoding of all frames, both requested on the CPU.
    """
    query_embedding = model.encode(query, device='cpu')
    # Frames are handed to the encoder as PIL images, not raw arrays.
    pil_frames = [Image.fromarray(frame) for frame in video]
    frame_embeddings = model.encode(pil_frames, device='cpu')
    return query_embedding, frame_embeddings
def find_frames(url, model, desc, top_k, text):
    """Display the sampled video frames that best match ``desc``.

    Decodes at most the first 120 seconds of the video at ``url`` with
    ffmpeg, keeps every 10th decoded frame, scores each kept frame against
    ``desc`` by cosine similarity of the embeddings returned by
    ``get_embedding``, and renders up to ``top_k`` frames with ``st.image``,
    highest-scoring frame first.

    Args:
        url: Video location passed to ``ffmpeg.probe`` and ``ffmpeg.input``.
        model: Encoder forwarded to ``get_embedding``.
        desc: Text description to search for.
        top_k: Maximum number of frames to display.
        text: Streamlit placeholder used for the status message; it is
            cleared before results (or an error) are reported.

    Raises:
        ValueError: If the source contains no video stream, or if no frame
            could be decoded from it.
    """
    text.text("Processing video...")

    probe = ffmpeg.probe(url)
    video_stream = next(
        (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
        None,
    )
    if video_stream is None:
        # Without this check the lookup below fails with an opaque
        # "'NoneType' object is not subscriptable".
        text.empty()
        raise ValueError("No video stream found in the provided source")
    width = int(video_stream['width'])
    height = int(video_stream['height'])

    # Pipe raw RGB frames to stdout instead of writing a file; ``t=120``
    # limits decoding to the first two minutes of the video.
    out, _ = (
        ffmpeg
        .input(url, t=120)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    # One frame is height * width * 3 bytes; sample every 10th frame.
    video = (
        np
        .frombuffer(out, np.uint8)
        .reshape([-1, height, width, 3])
    )[::10]
    if len(video) == 0:
        text.empty()
        raise ValueError("No frames could be decoded from the video")

    txt_embd, img_embds = get_embedding(model, desc, video)
    cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
    # argsort is ascending: take the last ``top_k`` indices and reverse them
    # so the best match is displayed first.
    ids = np.argsort(cos_scores)[0][-top_k:][::-1]
    imgs = [Image.fromarray(video[i]) for i in ids]

    text.empty()
    st.image(imgs)
def main_page(model):
    """Render the landing page, which shows only the application title.

    Args:
        model: Not used by this page.
    """
    st.title("Introducing Youtube CLIFS")
def clifs_page(model):
    """Render the CLIFS search page and run a search when requested.

    The sidebar collects the number of results, a text description, and a
    YouTube URL. When "Search" is pressed, the direct media URL is resolved
    with youtube_dl (nothing is downloaded) and handed to ``find_frames``.

    Args:
        model: Encoder forwarded to ``find_frames``.
    """
    st.title("CLIFS")

    st.sidebar.markdown("### Controls:")
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=5,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Description",
        value="Two white puppies",
        help="Text description of what you want to find in the video",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL",
        value='https://youtu.be/I3AaW9ZevIU',
        help="Youtube video you'd like to search through",
    )

    submit_button = st.sidebar.button("Search")
    if not submit_button:
        return

    text = st.text("Downloading video...")
    # Request the 360p MP4 rendition; only metadata is fetched here.
    ydl_opts = {"format": "mp4[height=360]"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=False)

    # Some results (e.g. playlists) carry no top-level "url"; passing None on
    # to ffmpeg would fail with an unrelated-looking error.
    video_url = (info_dict or {}).get("url")
    if not video_url:
        text.empty()
        st.error("Could not resolve a streamable video URL for that link.")
        return

    find_frames(video_url, model, desc, top_k, text)
# Maps each sidebar navigation label to the function that renders that page.
PAGES = {"Home": main_page, "CLIFS": clifs_page}
def run():
    """Configure the Streamlit app, load the model, and render the chosen page."""
    st.set_page_config(page_title="Youtube CLIFS")

    # The model is obtained once here and passed to whichever page is shown.
    model = get_model()

    # Sidebar navigation: the selected label picks the page renderer.
    st.sidebar.title('Navigation')
    page_names = list(PAGES)
    selection = st.sidebar.radio("Go to", page_names)
    render_page = PAGES[selection]
    render_page(model)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    run()