# clifs / app.py
# ncoop57
# Add generating GIF from returned frames
# 81abbd1
import ffmpeg
import os
import torch
import uuid
import youtube_dl
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer, util, models
from clip import CLIPModel
from PIL import Image
@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Build the text and image encoders used for frame search.

    The result is memoized by ``st.cache`` (one entry, mutation allowed), so
    the models are constructed once rather than on every script rerun.

    Returns:
        ``(txt_model, vis_model)``: the ``clip-ViT-B-32-multilingual-v1``
        SentenceTransformer and a SentenceTransformer wrapping the
        ``CLIPModel`` module, both cast to float32 on the CPU.
    """
    cpu = torch.device('cpu')

    text_encoder = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    text_encoder = text_encoder.to(dtype=torch.float32, device=cpu)

    image_encoder = SentenceTransformer(modules=[CLIPModel()])
    image_encoder = image_encoder.to(dtype=torch.float32, device=cpu)

    return text_encoder, image_encoder
def get_embedding(txt_model, vis_model, query, video):
    """Embed a text query and every frame of a video.

    Args:
        txt_model: Encoder whose ``encode(..., device='cpu')`` embeds ``query``.
        vis_model: Encoder whose ``encode(..., device='cpu')`` embeds the frames.
        query: The text to embed.
        video: Iterable of frames, each converted with ``Image.fromarray``
            (presumably ``(H, W, 3)`` uint8 RGB arrays; TODO confirm for
            callers other than ``find_frames``).

    Returns:
        ``(text_emb, img_embs)``, exactly as returned by the two ``encode`` calls.
    """
    text_emb = txt_model.encode(query, device='cpu')
    # The vision encoder consumes PIL images, not raw arrays.
    frames = [Image.fromarray(frame) for frame in video]
    return text_emb, vis_model.encode(frames, device='cpu')
def _probe_frame_size(url):
    """Return ``(width, height)`` of the first video stream at *url*.

    Returns ``None`` when ``ffmpeg.probe`` reports no stream whose
    ``codec_type`` is ``'video'``.
    """
    probe = ffmpeg.probe(url)
    video_stream = next(
        (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
        None,
    )
    if video_stream is None:
        return None
    return int(video_stream['width']), int(video_stream['height'])


def _read_frames(url, seconds, width, height, frame_step):
    """Decode the first *seconds* of *url* and keep every *frame_step*-th frame.

    Returns a ``(n_frames, height, width, 3)`` uint8 array of RGB frames.
    """
    out, _ = (
        ffmpeg
        .input(url, t=seconds)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    # rgb24 rawvideo is a flat stream of height * width * 3 bytes per frame.
    frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
    return frames[::frame_step]


def _frames_to_gif(imgs, duration=200):
    """Encode a non-empty list of PIL images as a looping GIF in memory.

    *duration* is the per-frame display time handed to Pillow (milliseconds).
    Returns a ``BytesIO`` positioned at the start of the GIF data.
    """
    import io

    # from: https://stackoverflow.com/a/57751793/5768407
    buffer = io.BytesIO()
    imgs[0].save(fp=buffer, format='GIF', append_images=imgs[1:],
                 save_all=True, duration=duration, loop=0)
    buffer.seek(0)
    return buffer


def find_frames(url, txt_model, vis_model, desc, seconds, top_k, frame_step=10):
    """Find the video frames that best match a text query and display them.

    Decodes the first ``seconds`` of the stream at ``url`` with ffmpeg and
    samples every ``frame_step``-th frame. It ranks the samples by cosine
    similarity between their embeddings and the embedding of ``desc``. The
    ``top_k`` best matches are rendered in chronological order, both as an
    animated GIF and as individual images.

    Args:
        url: Media URL readable by ffmpeg.
        txt_model: Text encoder forwarded to ``get_embedding``.
        vis_model: Image encoder forwarded to ``get_embedding``.
        desc: Search query text.
        seconds: How much of the video, from its start, to decode.
        top_k: Number of best-matching frames to show.
        frame_step: Keep one out of every ``frame_step`` decoded frames.

    When the URL has no video stream, or no frames could be decoded, an
    ``st.error`` message is shown and the function returns early.
    """
    text = st.text("Downloading video (Descargando video)...")
    # gif from https://giphy.com/gifs/alan-DfSXiR60W9MVq
    gif_runner = st.image("./loading.gif")
    try:
        frame_size = _probe_frame_size(url)
        if frame_size is None:
            st.error("No video stream found (No se encontró ninguna pista de video).")
            return
        width, height = frame_size

        video = _read_frames(url, seconds, width, height, frame_step)
        text.text("Processing video (Procesando video)...")
        if len(video) == 0:
            st.error("No frames could be decoded from the video "
                     "(No se pudieron extraer fotogramas del video).")
            return

        txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
        cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
        # argsort is ascending, so the last top_k indices are the best matches.
        # They are sorted again so the frames appear in video order.
        ids = np.argsort(cos_scores)[0][-top_k:]
        imgs = [Image.fromarray(video[i]) for i in sorted(ids)]

        # Built in memory: no temp file can be left behind on disk.
        gif_buffer = _frames_to_gif(imgs)
    finally:
        # Always remove the loading indicators, even after an error or an
        # early return, so they never stay on screen.
        gif_runner.empty()
        text.empty()
    st.image(gif_buffer)
    st.image(imgs)
# Markdown bodies for the English ("Home") and Spanish ("Inicio") landing
# pages, read once at import time from the working directory.
# The encoding is explicitly UTF-8, so non-ASCII characters (likely present
# in the Spanish page) decode the same way regardless of the platform's
# default locale encoding.
with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()
with open("INICIO.md", "r", encoding="utf-8") as f:
    INICIO_PAGINA = f.read()
def main_page(txt_model, vis_model):
    """Render the English landing page from ``HOME_PAGE``.

    ``txt_model`` and ``vis_model`` are unused here. They are accepted so
    every renderer in ``PAGES`` can be called with the same arguments.
    """
    heading = "Introducing Youtube CLIFS"
    st.title(heading)
    st.markdown(HOME_PAGE)
def inicio_pagina(txt_model, vis_model):
    """Render the Spanish landing page ("Inicio") from ``INICIO_PAGINA``.

    ``txt_model`` and ``vis_model`` are unused here. They are accepted so
    every renderer in ``PAGES`` can be called with the same arguments.
    """
    heading = "Presentando Youtube CLIFS"
    st.title(heading)
    st.markdown(INICIO_PAGINA)
def clifs_page(txt_model, vis_model):
    """Render the CLIFS search page.

    The sidebar collects how many seconds to scan, ``top_k``, the text query,
    the YouTube URL and the download height. When "Search" is pressed, the
    URL is resolved with ``youtube_dl`` (no download) to the stream URL of an
    ``mp4`` format at the chosen height. That URL is handed to
    ``find_frames``.

    Args:
        txt_model: Text encoder forwarded to ``find_frames``.
        vis_model: Image encoder forwarded to ``find_frames``.

    If ``youtube_dl`` cannot resolve the video, or returns no direct
    ``"url"``, an ``st.error`` message is shown instead of a traceback.
    """
    st.title("CLIFS")
    st.sidebar.markdown("### Controls (Controles):")
    seconds = st.sidebar.slider(
        "How many seconds of video to consider? (¿Cuántos segundos de video considerar?)",
        min_value=10,
        max_value=120,
        value=60,
        step=1,
    )
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=20,
        value=10,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Query (Consulta de Búsqueda)",
        value="Pancake in the shape of an otter",
        help="Text description of what you want to find in the video (Descripción de texto de que desea encontrar en el video)",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL (URL del Video de Youtube)",
        value='https://youtu.be/xUv6XgPwGaQ',
        help="Youtube video you want to search (Video de Youtube que desea búscar)",
    )
    quality = st.sidebar.radio(
        "Quality of the Video (Calidad del Video)",
        [144, 240, 360, 480],
        help="Quality of the video to download. Higher quality takes more time (Calidad del video para descargar. Calidad más alta toma más tiempo)",
    )
    submit_button = st.sidebar.button("Search (Búscar)")
    if not submit_button:
        return

    # Only an mp4 at exactly the selected height is requested; videos that do
    # not offer one make youtube_dl raise DownloadError.
    ydl_opts = {"format": f"mp4[height={quality}]"}
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False)
    except youtube_dl.utils.DownloadError as err:
        st.error(f"Could not load the video (No se pudo cargar el video): {err}")
        return

    video_url = info_dict.get("url", None)
    if video_url is None:
        st.error("Could not get a direct stream URL for this video "
                 "(No se pudo obtener un enlace directo al video).")
        return
    find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)
# Sidebar label -> page renderer. Insertion order is the order of the options
# in the "Go to (Ir a)" radio built in run(). Each renderer is called as
# renderer(txt_model, vis_model).
PAGES = dict(
    CLIFS=clifs_page,
    Home=main_page,
    Inicio=inicio_pagina,
)
def run():
    """Configure the page, load the models, and render the selected page.

    The sidebar radio lists the keys of ``PAGES``. The chosen renderer is
    called with the cached ``(txt_model, vis_model)`` pair from
    ``get_model``.
    """
    # Page configuration is issued before any other Streamlit call here.
    st.set_page_config(page_title="Youtube CLIFS")
    # main body
    txt_model, vis_model = get_model()
    st.sidebar.title("Navigation (Navegación)")
    selection = st.sidebar.radio("Go to (Ir a)", list(PAGES))
    # Renderers draw directly to the page; their return value is not used.
    PAGES[selection](txt_model, vis_model)
# Render the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    run()