import streamlit as st
# from pytube import YouTube
# from pytube import extract
import cv2
from PIL import Image
import clip as openai_clip
import torch
import math
import numpy as np
import tempfile
# from humanfriendly import format_timespan
import json
import sys
from random import randrange
import logging
# from pyunsplash import PyUnsplash
import requests
import io
from io import BytesIO
import base64
import altair as alt
from streamlit_vega_lite import altair_component
import pandas as pd
from datetime import timedelta
from decord import VideoReader, cpu, gpu
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.editor import *
import glob


# @st.cache(show_spinner=False)
def load_model():
    # Load CLIP ViT-B/32 once and keep it in session state for reuse across reruns.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess = openai_clip.load('ViT-B/32', device=device)
    st.session_state.model = model
    st.session_state.preprocess = preprocess
    st.session_state.device = device


def fetch_video(url):
    # Requires pytube (import commented out above); not used in the current demo flow.
    yt = YouTube(url)
    streams = yt.streams.filter(adaptive=True, subtype='mp4', resolution='360p', only_video=True)
    length = yt.length
    if length >= 300:
        st.error('Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.')
        st.stop()
    video = streams[0]
    return video, video.url


# @st.cache()
# def extract_frames(video):
#     frames = []
#     capture = cv2.VideoCapture(video)
#     fps = capture.get(cv2.CAP_PROP_FPS)
#     current_frame = 0
#     while capture.isOpened():
#         ret, frame = capture.read()
#         if ret == True:
#             frames.append(Image.fromarray(frame[:, :, ::-1]))
#         else:
#             break
#         current_frame += fps
#         capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
#     # print(f'Frames extracted: {len(frames)}')
#     return frames, fps


# @st.cache()
def video_to_frames(video):
    # Sample roughly one frame per second of video with decord.
    vr = VideoReader(video)
    frames = []
    frame_count = len(vr)
    fps = vr.get_avg_fps()
    for i in range(0, frame_count, round(fps)):
    # for i in range(0, frame_count):
        frame = vr[i].asnumpy()
        y_dim = frame.shape[0]
        x_dim = frame.shape[1]
        frames.append(Image.fromarray(frame))
    return frames, fps, x_dim, y_dim


def video_to_info(video):
    # Read only the first frame to get the fps and frame dimensions.
    vr = VideoReader(video)
    frame_count = len(vr)
    fps = vr.get_avg_fps()
    frame = vr[0].asnumpy()
    y_dim = frame.shape[0]
    x_dim = frame.shape[1]
    return fps, x_dim, y_dim


# @st.cache()
def encode_frames(video_frames):
    # Encode the sampled frames with CLIP in batches and L2-normalize each embedding.
    batch_size = 256
    batches = math.ceil(len(video_frames) / batch_size)
    video_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    for i in range(batches):
        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
        batch_preprocessed = torch.stack([st.session_state.preprocess(frame) for frame in batch_frames]).to(st.session_state.device)
        with torch.no_grad():
            batch_features = st.session_state.model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        video_features = torch.cat((video_features, batch_features))
    # print(f'Features: {video_features.shape}')
    return video_features

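
# Note: encode_frames() and encode_photos() both return an (N, 512) tensor of
# unit-norm CLIP ViT-B/32 image embeddings (one row per sampled frame or photo).
# classify_activity() and compute_scores() below consume these embeddings via
# dot products against text and photo embeddings, respectively.
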
def classify_activity(video_features, activities_list):
    # Zero-shot activity classification: score the first frame's embedding against the
    # tokenized activity labels and return the top-5 labels.
    # (Only called from the commented-out keyword-detection block in the main script.)
    text = torch.cat([openai_clip.tokenize(f'{activity}') for activity in activities_list]).to(st.session_state.device)
    with torch.no_grad():
        text_features = st.session_state.model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    logit_scale = st.session_state.model.logit_scale.exp()
    video_features = torch.from_numpy(video_features)
    similarities = (logit_scale * video_features @ text_features.t()).softmax(dim=-1)
    probs, word_idxs = similarities[0].topk(5)
    primary_activity = []
    for prob, word_idx in zip(probs, word_idxs):
        primary_activity.append(activities_list[word_idx])
        # primary_activity = activities_list[word_idx]
    return primary_activity


def encode_photos(photos):
    # Same batching as encode_frames(), but for image files on disk.
    batch_size = 256
    batches = math.ceil(len(photos) / batch_size)
    video_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    for i in range(batches):
        batch_frames = photos[i*batch_size : (i+1)*batch_size]
        batch_preprocessed = torch.stack([st.session_state.preprocess(Image.open(frame)) for frame in batch_frames]).to(st.session_state.device)
        with torch.no_grad():
            batch_features = st.session_state.model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        video_features = torch.cat((video_features, batch_features))
    # print(f'Features: {video_features.shape}')
    return video_features


def img_to_bytes(img):
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()
    return img_byte_arr


def normalize(vector):
    # Min-max normalize to [0, 1].
    return (vector - np.min(vector)) / (np.max(vector) - np.min(vector))


def format_img(img):
    # Downscale a frame and return it as a base64 data URI for the chart's hover thumbnail.
    size = 150, 150
    # img = Image.fromarray(img)
    img.thumbnail(size, Image.Resampling.LANCZOS)
    output = io.BytesIO()
    img.save(output, format='PNG')
    encoded_string = f'data:image/png;base64,{base64.b64encode(output.getvalue()).decode()}'
    return encoded_string


def get_photos(keyword):
    # Load the bundled reference photo(s) for the selected domain.
    photo_collection = []
    for filename in glob.glob(f'photos/{st.session_state.domain.lower()}/*.jpeg')[:1]:
        photo = Image.open(filename)
        photo_collection.append(photo)
    return photo_collection
    # # api_key = 'hzcKZ0e4we95wSd8_ip2zTB3m2DrOMWehAxrYjqjwg0'
    # api_key = 'fZ1nE7Y4NC-iYGmqgv-WuyM8m9p0LroCdAOZOR6tyho'
    # unsplash_search = PyUnsplash(api_key=api_key)
    # logging.getLogger('pyunsplash').setLevel(logging.DEBUG)
    # search = unsplash_search.search(type_='photos', query=keyword)  # per_page
    # photo_collection = []
    # # st.markdown(f'**Unsplash photos for `{keyword}`**')
    # for result in search.entries:
    #     photo_url = result.link_download
    #     response = requests.get(photo_url)
    #     photo = Image.open(BytesIO(response.content))
    #     # st.image(photo, width=200)
    #     photo_collection.append(photo)
    # return photo_collection


def display_results(best_photo_idx):
    st.markdown('**Top 10 highlights**')
    result_arr = []
    for frame_id in best_photo_idx:
        result = st.session_state.video_frames[frame_id]
        st.image(result)
    return result_arr


def make_df(similarities):
    # Build the chart dataframe: one row per sampled frame, with the normalized
    # (and sharpened) highlight score plus a thumbnail of the frame.
    # `keyword` is the module-level variable set in the progress == 2 block below.
    df = pd.DataFrame()
    df['keyword'] = [keyword] * len(similarities)
    df['x'] = [i for i, _ in enumerate(similarities)]
    df['y'] = normalize(np.power(similarities, 8)).flatten()  # similarities arrives as (N, 1)
    df['image'] = [format_img(frame) for frame in st.session_state.video_frames]
    return df


# @st.cache()
def compute_scores(search_query, video_features, text_query, display_results_count=10):
    # Average the CLIP embeddings of the reference photos and score every frame by
    # its (scaled) dot product with that average.
    sum_photo = torch.zeros(1, 512)
    for photo in search_query:
        with torch.no_grad():
            image_features = st.session_state.model.encode_image(st.session_state.preprocess(photo).unsqueeze(0).to(st.session_state.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        sum_photo += image_features.float().cpu()  # accumulate (the old running sum double-counted itself)
    avg_photo = sum_photo / len(search_query)
    video_features = torch.from_numpy(video_features).float()
    similarities = (100.0 * video_features @ avg_photo.T)  # values,
    best_photo_idx = similarities.topk(display_results_count, dim=0)
    # display_results(best_photo_idx)
    return similarities.cpu().numpy()

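
# Example usage (a sketch mirroring the progress == 2 block below): `img_set` is the
# list of reference photos for the domain and `keyword` is the domain name in lowercase.
#   img_set = get_photos(keyword)
#   similarities = compute_scores(img_set, st.session_state.video_features, keyword)
#   df = make_df(similarities)
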
def avenir():
    # Minimal Altair theme that renders chart text in Avenir.
    font = 'Avenir'
    return {
        'config': {
            'title': {'font': font},
            'axis': {
                'labelFont': font,
                'titleFont': font
            }
        }
    }


alt.themes.register('avenir', avenir)
alt.themes.enable('avenir')

# TODO: Make playhead scores and average according to keyword
# TODO: Maximum interval selection
# TODO: Interactive legend https://altair-viz.github.io/gallery/interactive_legend.html
# TODO: Multi-line highlight https://altair-viz.github.io/gallery/multiline_highlight.html


@st.cache
def draw_chart(df, mode):
    if mode == 'Automatic':
        # Hover mode: a rule, point, score label, and frame thumbnail follow the cursor.
        nearest = alt.selection(type='single', nearest=True, on='mouseover', empty='none')
        line = alt.Chart(df).mark_line().encode(
            x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
            y=alt.Y('y', axis=alt.Axis(labels=False, tickSize=0, title='')),
            # color=alt.Color('keyword:N', scale=alt.Scale(scheme='tableau20')),
            color=alt.value('#00C7BE'),
            # color=alt.Color('#9b59b6'),
        )
        selectors = alt.Chart(df).mark_point().encode(
            x='x:Q',
            opacity=alt.value(0),
        ).add_selection(
            nearest
        )
        rules = alt.Chart(df).mark_rule(color='black').encode(
            x='x:Q',
        ).transform_filter(
            nearest
        )
        points = line.mark_point().encode(
            opacity=alt.condition(nearest, alt.value(1), alt.value(0))
        )
        text = line.mark_text(align='center', yOffset=-110, fontSize=16).encode(
            text=alt.condition(nearest, 'y:N', alt.value(' ')),
            color=alt.value('#000000'),
            # fontSize=30
        ).transform_calculate(y='format(datum.y, ".2f")')
        image = line.mark_image(align='center', width=150, height=150, yOffset=-60).encode(
            url=alt.condition(nearest, 'image', alt.value(' '))
        )
        chart = alt.layer(line, selectors, points, rules, text, image)
    elif mode == 'brush':
        # Unused experimental mode: brush an interval and label it with the mean score.
        brush = alt.selection(type='interval', encodings=['x'])
        line = alt.Chart(df).mark_line().encode(
            # https://www.rdocumentation.org/packages/vegalite/versions/0.6.1/topics/mark_line
            x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
            y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
            # color=alt.Color('keyword:N', scale=alt.Scale(scheme='tableau20')),
            color=alt.value('#00C7BE'),
        ).add_selection(
            brush
        )
        text = alt.Chart(df).transform_filter(brush).mark_text(
            align='right',
            # baseline='top',
            # dx=1500
            dx=750,
            dy=-12,
            fontSize=24,
            fontWeight=800,
        ).encode(
            # x='max(x):Q',
            y='mean(y):Q',
            # dy=alt.value(10),
            text=alt.Text('mean(y):Q', format='.2f'),
        )
        average = alt.Chart(df).mark_rule(color='black', strokeDash=[5, 5]).encode(
            y='mean(y):Q',
            # size=alt.SizeValue(3),
        ).transform_filter(
            brush
        )
        # chart = alt.layer(line, average, text)
        chart = line
    elif mode == 'User selection':
        # Brush mode: the selected x-interval is read back through streamlit_vega_lite.
        brush = alt.selection(type='interval', encodings=['x'])
        line = alt.Chart(df).mark_line().encode(
            # https://www.rdocumentation.org/packages/vegalite/versions/0.6.1/topics/mark_line
            x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
            y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
            # color=alt.Color('keyword:N', scale=alt.Scale(scheme='tableau20')),
            color=alt.value('#00C7BE'),
        ).add_selection(
            brush
        )
        text = alt.Chart(df).transform_filter(brush).mark_text(
            align='right',
            # baseline='top',
            # dx=1500
            dx=750,
            dy=-12,
            fontSize=24,
            fontWeight=800,
        ).encode(
            # x='max(x):Q',
            y='mean(y):Q',
            # dy=alt.value(10),
            text=alt.Text('mean(y):Q', format='.2f'),
        )
        average = alt.Chart(df).mark_rule(color='black', strokeDash=[5, 5]).encode(
            y='mean(y):Q',
            # size=alt.SizeValue(3),
        ).transform_filter(
            brush
        )
        # chart = alt.layer(line, average, text)
        chart = line
    return chart.properties(width=1250, height=500).configure_axis(grid=False, domain=False).configure_view(strokeOpacity=0)
    # return line

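
# max_subarray() below is a fixed-length sliding-window search over the per-frame
# (roughly per-second) highlight scores: it returns the best window's score sum and
# its [left, right) index range, which edit_video() reuses as start and end seconds
# for the subclip. For example, max_subarray([0, 3, 1, 4, 1], 2) returns (5, 2, 4),
# i.e. the two-element window [1, 4].
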
def max_subarray(arr, k):
    # Return the maximum sum over any k consecutive scores, plus that window's
    # [left, right) index range.
    n = len(arr)
    if n < k:
        st.write('Video too short')
        st.stop()
    res = 0
    left = 0
    right = k
    for i in range(k):
        res += arr[i]
    curr_sum = res
    for i in range(k, n):
        curr_sum += arr[i] - arr[i - k]
        if curr_sum > res:
            res = curr_sum
            left = i - k + 1
            right = i + 1
    return res, left, right

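
# edit_video() below cuts a highlight for each music template: it grabs the
# highest-scoring window of the right length via max_subarray(), assembles black
# gaps ("blank") and short clips ("flash"/"highlight") timed to that track, and
# finally muxes the song over the result.
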
def edit_video(template, df_all):
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    if template == 'Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 7)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/coming-in-hot.mp3'
        blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.6)
        flash1 = video.subclip(t_start=0, t_end=1.2)
        blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash2 = video.subclip(t_start=1.3, t_end=1.4)
        blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash3 = video.subclip(t_start=1.5, t_end=3.3)
        blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash4 = video.subclip(t_start=3.4, t_end=3.5)
        blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash5 = video.subclip(t_start=3.6, t_end=4.6)
        blank6 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash6 = video.subclip(t_start=4.7, t_end=4.8)
        blank7 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        highlight = video.subclip(t_start=4.9, t_end=6.384)
        output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, flash5, blank6, flash6, blank7, highlight])
    elif template == 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 7)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/thinking-out-loud.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.6)
        highlight = video.subclip(t_start=0, t_end=6.852)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Sheesh by Surfaces (upbeat, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 8)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/sheesh.mp3'
        blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=3.5)
        flash1 = video.subclip(t_start=0, t_end=0.1)
        blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash2 = video.subclip(t_start=0.2, t_end=0.3)
        blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash3 = video.subclip(t_start=0.4, t_end=0.5)
        blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash4 = video.subclip(t_start=0.6, t_end=0.7)
        blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.9)
        highlight = video.subclip(t_start=1.6, t_end=7.18408163265)
        output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, highlight])
    elif template == 'Moon by Kid Francescoli (tranquil, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 9)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/and-it-went-like.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.9)
        highlight = video.subclip(t_start=0, t_end=8.132)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Ready Set by Joey Valence & Brae (old school, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 11)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/ready-set.mp3'
        highlight = video.subclip(t_start=0, t_end=10.512)
        output = highlight
    elif template == 'Lovewave by The 1-800 (tranquil, 13 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 12)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/lovewave.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2.1)
        highlight = video.subclip(t_start=0, t_end=11.58)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 16)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/and-it-sounds-like.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2)
        highlight = video.subclip(t_start=0, t_end=15.928)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Comfort Chain by Instupendo (lofi, 18 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 19)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/comfort-chain.mp3'
        highlight = video.subclip(t_start=0, t_end=18.432000000000002)
        output = highlight
    # st.write(res, left, right)
    song = AudioFileClip(music_path)
    output = output.set_audio(song)
    output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    st.video('output.mp4')
    # return output


def crop_video(df_all, left, right):
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    video = VideoFileClip(video_path)
    fps = video.fps
    music_path = 'music/loop.mp3'
    song = AudioFileClip(music_path)
    video = video.set_audio(song)
    output = video.subclip(t_start=left, t_end=right)
    output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    st.video('output.mp4')
    # return output

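
# App flow (two steps, tracked via st.session_state.progress): step 1 loads CLIP and,
# on "Process video", samples ~1 frame/second from the selected demo video, encodes
# the frames, and caches frames/features under files/. Step 2 scores every frame
# against the domain's reference photos, charts the scores, and generates a highlight
# clip either automatically (music template) or from a brushed interval.
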
st.set_page_config(page_title='Videogenic', page_icon='✨', layout='wide', initial_sidebar_state='collapsed')

# Hide Streamlit's default menu and footer.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# clustrmaps = """
# """
# st.markdown(clustrmaps, unsafe_allow_html=True)

# ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)

st.title('Videogenic ✨')

if 'progress' not in st.session_state:
    st.session_state.progress = 1

# mode = 'play'
# mode = 'brush'
# mode = 'select'

if st.session_state.progress == 1:
    with st.spinner('Loading model...'):
        load_model()
    domain = st.selectbox('Select video', ('Skydiving', 'Surfing', 'Skateboarding'))
    # Entire journey, montage, vlog
    st.session_state.domain = domain
    if st.button('Process video'):
        with st.spinner('Processing video...'):
            video_name = f'videos/{st.session_state.domain.lower()}.mp4'
            with open(video_name, 'rb') as video_file:
                video_bytes = video_file.read()
            st.session_state.video = video_bytes
            # st.video(st.session_state.video)
            # Extract frames once and cache them to disk; later runs can comment out
            # the extraction lines and just reload the saved .npy files.
            video_frames, fps, x_dim, y_dim = video_to_frames(video_name)  # first run; video_to_info
            np.save(f'files/{st.session_state.domain.lower()}.npy', video_frames)
            fps, x_dim, y_dim = video_to_info(video_name)
            video_frames = np.load(f'files/{st.session_state.domain.lower()}.npy', allow_pickle=True)
            st.session_state.video_frames = video_frames
            st.session_state.fps = fps
            st.session_state.x_dim = x_dim
            st.session_state.y_dim = y_dim
            print('Extracted frames')
            encoded_frames = encode_frames(video_frames)  # first run
            np.save(f'files/{st.session_state.domain.lower()}_features.npy', encoded_frames.cpu().numpy())
            encoded_frames = np.load(f'files/{st.session_state.domain.lower()}_features.npy', allow_pickle=True)
            st.session_state.video_features = encoded_frames
            print('Encoded frames')
            st.session_state.progress = 2
            # with open('activities.txt') as f:
            #     activities_list = [line.rstrip('\n') for line in f]
            # keywords = classify_activity(st.session_state.video_features, activities_list)
            # st.write(keywords)

if st.session_state.progress == 2:
    mode = st.radio('Select mode', ('Automatic', 'User selection'))
    st.session_state.mode = mode
    # keywords = list(st.text_input('Enter topic').split(','))
    # if st.button('Compute scores') and keywords is not None:
    with st.spinner('Computing highlight scores...'):
        keyword = st.session_state.domain.lower()
        df_list = []
        # for keyword in keywords:
        img_set = get_photos(keyword)
        similarities = compute_scores(img_set, st.session_state.video_features, keyword)
        # st.write(similarities)
        df = make_df(similarities)
        df_list.append(df)
        df_all = pd.concat(df_list, ignore_index=True, sort=False)
        st.session_state.df_all = df_all
        # st.write(df_all)
    # highlight_length = 7.033
    # st.write(st.session_state.fps)
    with st.spinner('Visualizing highlight scores...'):
        # In 'User selection' mode, altair_component() returns the brush selection as
        # {'x': [start, end]} in chart units (one unit per sampled frame, i.e. ~seconds).
        selection = altair_component(draw_chart(df_all, st.session_state.mode))
        print(selection)
        st.session_state.selection = selection
        # if '_vgsid_' in selection:
        #     # the ids start at 1
        #     st.write(df.iloc[[selection['_vgsid_'][0] - 1]])
        # else:
        #     st.info('Hover over the chart above to see details about the Penguin here.')
        # if 'x' in selection:
        #     # the ids start at 1
        #     st.write(selection['x'])
    # chart = draw_chart(df_all, mode)
    # st.altair_chart(chart, use_container_width=False)
    # st.session_state.progress = 3

    # if st.session_state.progress == 3:
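    # Two generation paths: 'Automatic' trims the top-scoring window to a music
    # template via edit_video(); 'User selection' expects a brushed interval on the
    # chart above and crops exactly that range via crop_video().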
    if st.session_state.mode == 'Automatic':
        # template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)',
        #                                             'Moon by Kid Francescoli (tranquil, 10 seconds)', 'Ready Set by Joey Valence & Brae (old school, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)',
        #                                             'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)', 'Comfort Chain by Instupendo (lofi, 18 seconds)'])
        template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)'])
        if st.button('Generate video'):
            with st.spinner('Generating highlight video...'):
                edit_video(template, st.session_state.df_all)
            st.balloons()
    elif st.session_state.mode == 'User selection':
        if st.button('Generate video'):
            left = st.session_state.selection['x'][0]
            right = st.session_state.selection['x'][1]
            with st.spinner('Generating highlight video...'):
                crop_video(st.session_state.df_all, left, right)
            st.balloons()