"""Youtube CLIFS: search the first seconds of a YouTube video for frames matching a text query.

Frames are encoded with a CLIP-style image encoder and compared against a
text embedding of the query. The best-matching frames are shown both as an
animated GIF and as individual images in a Streamlit app.
"""
import os
import uuid

import ffmpeg
import numpy as np
import streamlit as st
import torch
import youtube_dl
from PIL import Image
from sentence_transformers import SentenceTransformer, models, util

from clip import CLIPModel


@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Load the text and image encoders once per Streamlit session cache.

    Returns:
        A ``(txt_model, vis_model)`` tuple: the
        ``clip-ViT-B-32-multilingual-v1`` SentenceTransformer for text, and a
        SentenceTransformer wrapping the local ``CLIPModel`` for images.
        Both are moved to the CPU in float32.
    """
    cpu = torch.device("cpu")
    txt_model = SentenceTransformer("clip-ViT-B-32-multilingual-v1").to(
        dtype=torch.float32, device=cpu
    )
    clip = CLIPModel()
    vis_model = SentenceTransformer(modules=[clip]).to(dtype=torch.float32, device=cpu)
    return txt_model, vis_model


def get_embedding(txt_model, vis_model, query, video):
    """Encode the text query and every frame of ``video`` on the CPU.

    Args:
        txt_model: Text encoder returned by :func:`get_model`.
        vis_model: Image encoder returned by :func:`get_model`.
        query: Text to search for.
        video: Iterable of frames accepted by ``PIL.Image.fromarray``. In this
            app these are the uint8 RGB frames produced by :func:`find_frames`.

    Returns:
        ``(text_emb, img_embs)``: the query embedding and the frame embeddings,
        in the same order as ``video``.
    """
    text_emb = txt_model.encode(query, device="cpu")
    images = [Image.fromarray(frame) for frame in video]
    img_embs = vis_model.encode(images, device="cpu")
    return text_emb, img_embs


def _probe_frame_size(url):
    """Return ``(width, height)`` of the first video stream at ``url``, or ``None`` if it has none."""
    probe = ffmpeg.probe(url)
    video_stream = next(
        (stream for stream in probe["streams"] if stream["codec_type"] == "video"),
        None,
    )
    if video_stream is None:
        return None
    return int(video_stream["width"]), int(video_stream["height"])


def _decode_rgb_frames(url, seconds, width, height):
    """Decode the first ``seconds`` of ``url`` into an ``(n_frames, height, width, 3)`` uint8 array."""
    out, _ = (
        ffmpeg
        .input(url, t=seconds)
        .output("pipe:", format="rawvideo", pix_fmt="rgb24")
        .run(capture_stdout=True)
    )
    return np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])


def _show_as_gif(imgs):
    """Display ``imgs`` as a looping animated GIF (200 ms per frame) via a temporary file."""
    gif_path = f"./{uuid.uuid4().hex}.gif"
    try:
        # from: https://stackoverflow.com/a/57751793/5768407
        imgs[0].save(
            fp=gif_path,
            format="GIF",
            append_images=imgs[1:],
            save_all=True,
            duration=200,
            loop=0,
        )
        st.image(gif_path)
    finally:
        # Always remove the GIF, even if saving/displaying failed, so repeated
        # searches do not accumulate files on disk.
        if os.path.exists(gif_path):
            os.remove(gif_path)


def find_frames(url, txt_model, vis_model, desc, seconds, top_k, frame_step=10):
    """Find and display the frames of a video that best match a text description.

    The first ``seconds`` of the video at ``url`` are decoded as RGB frames.
    Every ``frame_step``-th frame is scored by cosine similarity against
    ``desc``. The ``top_k`` best frames are displayed in their original
    (chronological) order, first as an animated GIF, then as individual images.

    Args:
        url: A URL or path that ffmpeg can read directly.
        txt_model: Text encoder from :func:`get_model`.
        vis_model: Image encoder from :func:`get_model`.
        desc: Text description to search for.
        seconds: How many seconds from the start of the video to consider.
        top_k: Maximum number of frames to return. Fewer are shown if fewer
            frames were sampled.
        frame_step: Keep one of every ``frame_step`` decoded frames. The
            default is 10, which was previously hard-coded.

    Raises:
        ffmpeg.Error: If probing or decoding the video fails.
    """
    text = st.text("Downloading video (Descargando video)...")
    # gif from https://giphy.com/gifs/alan-DfSXiR60W9MVq
    gif_runner = st.image("./loading.gif")
    try:
        frame_size = _probe_frame_size(url)
        if frame_size is None:
            st.error("The URL does not contain a video stream (La URL no contiene un flujo de video).")
            return
        width, height = frame_size

        frames = _decode_rgb_frames(url, seconds, width, height)
        text.text("Processing video (Procesando video)...")
        video = frames[::frame_step]
        if len(video) == 0:
            st.error("No frames could be read from the video (No se pudieron leer fotogramas del video).")
            return

        txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
        cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
        # cos_scores has one row (a single query). argsort is ascending, so
        # the last top_k indices are the best matches.
        ids = np.argsort(cos_scores)[0][-top_k:]
        # Sort by frame index so the results play back in video order.
        imgs = [Image.fromarray(video[i]) for i in sorted(ids)]
    finally:
        # Clear the loading indicators whether or not processing succeeded.
        gif_runner.empty()
        text.empty()

    _show_as_gif(imgs)
    st.image(imgs)


with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()

with open("INICIO.md", "r", encoding="utf-8") as f:
    INICIO_PAGINA = f.read()


def main_page(txt_model, vis_model):
    """Render the English home page.

    The model arguments are unused. They keep the signature uniform with the
    other pages in ``PAGES``.
    """
    st.title("Introducing Youtube CLIFS")
    st.markdown(HOME_PAGE)


def inicio_pagina(txt_model, vis_model):
    """Render the Spanish home page.

    The model arguments are unused. They keep the signature uniform with the
    other pages in ``PAGES``.
    """
    st.title("Presentando Youtube CLIFS")
    st.markdown(INICIO_PAGINA)


def clifs_page(txt_model, vis_model):
    """Render the search page: sidebar controls plus, on submit, the matching frames.

    On submit, youtube_dl resolves a direct mp4 URL at the selected height
    without downloading the file. That URL is passed to :func:`find_frames`.
    Resolution and ffmpeg failures are reported with ``st.error`` instead of
    crashing the app.
    """
    st.title("CLIFS")
    st.sidebar.markdown("### Controls (Controles):")
    seconds = st.sidebar.slider(
        "How many seconds of video to consider? (¿Cuántos segundos de video considerar?)",
        min_value=10,
        max_value=120,
        value=60,
        step=1,
    )
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=20,
        value=10,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Query (Consulta de Búsqueda)",
        value="Pancake in the shape of an otter",
        help="Text description of what you want to find in the video (Descripción de texto de lo que desea encontrar en el video)",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL (URL del Video de Youtube)",
        value="https://youtu.be/xUv6XgPwGaQ",
        help="Youtube video you want to search (Video de Youtube que desea buscar)",
    )
    quality = st.sidebar.radio(
        "Quality of the Video (Calidad del Video)",
        [144, 240, 360, 480],
        help="Quality of the video to download. Higher quality takes more time (Calidad del video para descargar. Calidad más alta toma más tiempo)",
    )
    submit_button = st.sidebar.button("Search (Buscar)")
    if not submit_button:
        return

    ydl_opts = {"format": f"mp4[height={quality}]"}
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False)
    except youtube_dl.utils.DownloadError as err:
        # For example, no mp4 stream exists at the selected height.
        st.error(f"Could not fetch the video (No se pudo obtener el video): {err}")
        return

    video_url = info_dict.get("url")
    if video_url is None:
        st.error(
            "No direct video URL was found for this link "
            "(No se encontró una URL directa de video para este enlace)."
        )
        return

    try:
        find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)
    except ffmpeg.Error as err:
        st.error(f"Could not process the video (No se pudo procesar el video): {err}")


# Sidebar label -> page renderer. Every renderer takes (txt_model, vis_model).
PAGES = {
    "CLIFS": clifs_page,
    "Home": main_page,
    "Inicio": inicio_pagina,
}


def run():
    """Configure the Streamlit page, load the models and render the selected page."""
    # set_page_config must be the first Streamlit command executed.
    st.set_page_config(page_title="Youtube CLIFS")
    txt_model, vis_model = get_model()
    st.sidebar.title("Navigation (Navegación)")
    selection = st.sidebar.radio("Go to (Ir a)", list(PAGES))
    PAGES[selection](txt_model, vis_model)


if __name__ == "__main__":
    run()