import streamlit as st
# from pytube import YouTube
# from pytube import extract
import cv2
from PIL import Image
import clip as openai_clip
import torch
import math
import numpy as np
import tempfile
# from humanfriendly import format_timespan
import json
import sys
import os
from random import randrange
import logging
# from pyunsplash import PyUnsplash
import requests
import io
from io import BytesIO
import base64
import altair as alt
from streamlit_vega_lite import altair_component
import pandas as pd
from datetime import timedelta
from decord import VideoReader, cpu, gpu
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.editor import *
import glob
# @st.cache(show_spinner=False)
def load_model():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess = openai_clip.load('ViT-B/32', device=device)
    st.session_state.model = model
    st.session_state.preprocess = preprocess
    st.session_state.device = device
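# A sketch of how the commented-out caching above could look on newer Streamlit,
# where st.cache_resource is the recommended cache for models (assumes Streamlit >= 1.18):
# @st.cache_resource(show_spinner=False)
# def load_clip():
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     model, preprocess = openai_clip.load('ViT-B/32', device=device)
#     return model, preprocess, device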
def fetch_video(url):
    # Unused in the current local-video flow; requires the pytube imports at the
    # top of this file to be uncommented.
    yt = YouTube(url)
    streams = yt.streams.filter(adaptive=True, subtype='mp4', resolution='360p', only_video=True)
    length = yt.length
    if length >= 300:
        st.error('Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.')
        st.stop()
    video = streams[0]
    return video, video.url
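# Example usage (sketch; only works with the pytube imports above uncommented):
# stream, stream_url = fetch_video('https://www.youtube.com/watch?v=...')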
# @st.cache()
# def extract_frames(video):
#     frames = []
#     capture = cv2.VideoCapture(video)
#     fps = capture.get(cv2.CAP_PROP_FPS)
#     current_frame = 0
#     while capture.isOpened():
#         ret, frame = capture.read()
#         if ret == True:
#             frames.append(Image.fromarray(frame[:, :, ::-1]))
#         else:
#             break
#         current_frame += fps
#         capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
#     # print(f'Frames extracted: {len(frames)}')
#     return frames, fps
# @st.cache()
def video_to_frames(video):
    # Sample roughly one frame per second with decord.
    vr = VideoReader(video)
    frames = []
    frame_count = len(vr)
    fps = vr.get_avg_fps()
    for i in range(0, frame_count, round(fps)):
        frame = vr[i].asnumpy()
        y_dim = frame.shape[0]
        x_dim = frame.shape[1]
        frames.append(Image.fromarray(frame))
    return frames, fps, x_dim, y_dim
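# Because video_to_frames samples every round(fps) frames, index i in the frame
# list corresponds to roughly second i of the video. The chart's x-axis and the
# sliding-window search in max_subarray below both rely on this convention.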
def video_to_info(video):
    # Read only the metadata (fps and frame dimensions) without extracting frames.
    vr = VideoReader(video)
    fps = vr.get_avg_fps()
    frame = vr[0].asnumpy()
    y_dim = frame.shape[0]
    x_dim = frame.shape[1]
    return fps, x_dim, y_dim
# @st.cache()
def encode_frames(video_frames):
    # Encode the sampled frames with CLIP in batches of 256; returns an (N, 512)
    # tensor of unit-normalized image embeddings.
    batch_size = 256
    batches = math.ceil(len(video_frames) / batch_size)
    video_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    for i in range(batches):
        batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
        # Frames reloaded from .npy arrive as ndarrays, so coerce them back to PIL.
        batch_preprocessed = torch.stack([st.session_state.preprocess(frame if isinstance(frame, Image.Image) else Image.fromarray(frame)) for frame in batch_frames]).to(st.session_state.device)
        with torch.no_grad():
            batch_features = st.session_state.model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        video_features = torch.cat((video_features, batch_features))
    # print(f'Features: {video_features.shape}')
    return video_features
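# Rough shape sketch: a 5-minute video sampled at one frame per second yields a
# (~300, 512) tensor of unit-normalized embeddings, so downstream similarity
# scoring reduces to a plain matrix product.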
def classify_activity(video_features, activities_list):
    # Zero-shot activity recognition: rank the text labels against the features
    # of the first sampled frame.
    text = torch.cat([openai_clip.tokenize(f'{activity}') for activity in activities_list]).to(st.session_state.device)
    with torch.no_grad():
        text_features = st.session_state.model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    logit_scale = st.session_state.model.logit_scale.exp()
    video_features = torch.from_numpy(video_features)
    similarities = (logit_scale * video_features @ text_features.t()).softmax(dim=-1)
    probs, word_idxs = similarities[0].topk(5)
    primary_activity = []
    for prob, word_idx in zip(probs, word_idxs):
        primary_activity.append(activities_list[word_idx])
    return primary_activity
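# Example usage (sketch, mirroring the commented-out block near the bottom of
# this file; assumes activities.txt holds one label per line):
# with open('activities.txt') as f:
#     activities_list = [line.rstrip('\n') for line in f]
# top5 = classify_activity(st.session_state.video_features, activities_list)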
def encode_photos(photos):
    # Batched CLIP encoding for reference photos (same pattern as encode_frames).
    batch_size = 256
    batches = math.ceil(len(photos) / batch_size)
    photo_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    for i in range(batches):
        batch_photos = photos[i*batch_size : (i+1)*batch_size]
        batch_preprocessed = torch.stack([st.session_state.preprocess(Image.open(photo)) for photo in batch_photos]).to(st.session_state.device)
        with torch.no_grad():
            batch_features = st.session_state.model.encode_image(batch_preprocessed)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        photo_features = torch.cat((photo_features, batch_features))
    return photo_features
def img_to_bytes(img):
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()
    return img_byte_arr
def normalize(vector):
    return (vector - np.min(vector)) / (np.max(vector) - np.min(vector))
def format_img(img):
    # Downscale to a thumbnail and return a base64 data URI, which is what
    # Altair's mark_image expects as a URL.
    size = 150, 150
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)  # frames reloaded from .npy arrive as ndarrays
    img.thumbnail(size, Image.Resampling.LANCZOS)
    output = io.BytesIO()
    img.save(output, format='PNG')
    encoded_string = f'data:image/png;base64,{base64.b64encode(output.getvalue()).decode()}'
    return encoded_string
def get_photos(keyword):
    # Load a local reference photo for the selected domain. keyword is only used
    # by the commented-out Unsplash alternative below.
    photo_collection = []
    for filename in glob.glob(f'photos/{st.session_state.domain.lower()}/*.jpeg')[:1]:
        photo = Image.open(filename)
        photo_collection.append(photo)
    return photo_collection
    # # api_key = 'hzcKZ0e4we95wSd8_ip2zTB3m2DrOMWehAxrYjqjwg0'
    # api_key = 'fZ1nE7Y4NC-iYGmqgv-WuyM8m9p0LroCdAOZOR6tyho'
    # unsplash_search = PyUnsplash(api_key=api_key)
    # logging.getLogger('pyunsplash').setLevel(logging.DEBUG)
    # search = unsplash_search.search(type_='photos', query=keyword)  # per_page
    # photo_collection = []
    # # st.markdown(f'**Unsplash photos for `{keyword}`**')
    # for result in search.entries:
    #     photo_url = result.link_download
    #     response = requests.get(photo_url)
    #     photo = Image.open(BytesIO(response.content))
    #     # st.image(photo, width=200)
    #     photo_collection.append(photo)
    # return photo_collection
def display_results(best_photo_idx):
    # Currently unused in the main flow; renders the top-scoring frames.
    st.markdown('**Top 10 highlights**')
    result_arr = []
    for frame_id in best_photo_idx:
        result = st.session_state.video_frames[frame_id]
        st.image(result)
        result_arr.append(result)
    return result_arr
def make_df(similarities, keyword):
    # One row per sampled frame: x is the frame index (~seconds), y the normalized
    # score, and image a thumbnail data URI shown on hover in the chart. Cast to
    # float64 so the 8th power below cannot overflow half-precision scores.
    similarities = np.asarray(similarities, dtype=np.float64).flatten()
    df = pd.DataFrame()
    df['keyword'] = [keyword] * len(similarities)
    df['x'] = [i for i, _ in enumerate(similarities)]
    df['y'] = normalize(np.power(similarities, 8))
    df['image'] = [format_img(frame) for frame in st.session_state.video_frames]
    return df
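# Note: y = normalize(similarities ** 8). Raising scores to the 8th power before
# normalizing stretches the gap between the best frames and average ones, which
# makes the peaks in the chart easier to spot and brush over.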
# @st.cache()
def compute_scores(search_query, video_features, text_query, display_results_count=10):
    # Score every frame by similarity to the average CLIP embedding of the
    # reference photos. (text_query and display_results_count are unused.)
    sum_photo = torch.zeros(1, 512).to(st.session_state.device)
    for photo in search_query:
        with torch.no_grad():
            image_features = st.session_state.model.encode_image(st.session_state.preprocess(photo).unsqueeze(0).to(st.session_state.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        sum_photo += image_features
    avg_photo = sum_photo / len(search_query)
    video_features = torch.from_numpy(video_features).to(st.session_state.device).to(avg_photo.dtype)
    similarities = (100.0 * video_features @ avg_photo.T)
    # values, best_photo_idx = similarities.topk(display_results_count, dim=0)
    # display_results(best_photo_idx)
    return similarities.cpu().numpy()
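# Shape sketch: video_features is (N, 512) and avg_photo is (1, 512), so the
# result is an (N, 1) column of scores: 100x the dot product between each
# unit-normalized frame embedding and the mean photo embedding.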
def avenir():
    font = 'Avenir'
    return {
        'config': {
            'title': {'font': font},
            'axis': {
                'labelFont': font,
                'titleFont': font
            }
        }
    }
alt.themes.register('avenir', avenir)
alt.themes.enable('avenir')
# TODO: Make playhead scores and average according to keyword
# TODO: Maximum interval selection
# TODO: Interactive legend https://altair-viz.github.io/gallery/interactive_legend.html
# TODO: Multi-line highlight https://altair-viz.github.io/gallery/multiline_highlight.html
def draw_chart(df, mode):
    if mode == 'Automatic':
        # Hover interaction: a vertical rule plus a score readout and frame
        # thumbnail at the x-position nearest the cursor.
        nearest = alt.selection(type='single', nearest=True, on='mouseover', empty='none')
        line = alt.Chart(df).mark_line().encode(
            x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
            y=alt.Y('y', axis=alt.Axis(labels=False, tickSize=0, title='')),
            # color=alt.Color('keyword:N', scale=alt.Scale(scheme='tableau20')),
            color=alt.value('#00C7BE'),
        )
        selectors = alt.Chart(df).mark_point().encode(
            x='x:Q',
            opacity=alt.value(0),
        ).add_selection(
            nearest
        )
        rules = alt.Chart(df).mark_rule(color='black').encode(
            x='x:Q',
        ).transform_filter(
            nearest
        )
        points = line.mark_point().encode(
            opacity=alt.condition(nearest, alt.value(1), alt.value(0))
        )
        text = line.mark_text(align='center', yOffset=-110, fontSize=16).encode(
            text=alt.condition(nearest, 'y:N', alt.value(' ')),
            color=alt.value('#000000'),
        ).transform_calculate(y='format(datum.y, ".2f")')
        image = line.mark_image(align='center', width=150, height=150, yOffset=-60).encode(
            url=alt.condition(nearest, 'image', alt.value(' '))
        )
        chart = alt.layer(line, selectors, points, rules, text, image)
    elif mode == 'User selection':
        # Brush interaction: drag horizontally to pick the interval to export.
        brush = alt.selection(type='interval', encodings=['x'])
        line = alt.Chart(df).mark_line().encode(
            x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
            y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
            # color=alt.Color('keyword:N', scale=alt.Scale(scheme='tableau20')),
            color=alt.value('#00C7BE'),
        ).add_selection(
            brush
        )
        text = alt.Chart(df).transform_filter(brush).mark_text(
            align='right',
            dx=750,
            dy=-12,
            fontSize=24,
            fontWeight=800,
        ).encode(
            y='mean(y):Q',
            text=alt.Text('mean(y):Q', format='.2f'),
        )
        average = alt.Chart(df).mark_rule(color='black', strokeDash=[5, 5]).encode(
            y='mean(y):Q',
        ).transform_filter(
            brush
        )
        # chart = alt.layer(line, average, text)
        chart = line
    return chart.properties(width=1250, height=500).configure_axis(grid=False, domain=False).configure_view(strokeOpacity=0)
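# streamlit_vega_lite's altair_component renders the chart and returns the current
# selection state as a dict. For the interval brush above it stays empty ({})
# until the user drags, then looks like {'x': [left_second, right_second]}.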
def max_subarray(arr, k):
    # Sliding-window search for the k consecutive scores with the largest sum.
    # Returns (best_sum, left, right) where [left, right) is the winning window.
    n = len(arr)
    if n < k:
        st.write('Video too short')
        return sum(arr), 0, n
    res = sum(arr[:k])
    curr_sum = res
    left = 0
    right = k
    for i in range(k, n):
        curr_sum += arr[i] - arr[i - k]
        if curr_sum > res:
            res = curr_sum
            left = i - k + 1
            right = i + 1
    return res, left, right
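# Worked example: max_subarray([1, 2, 3, 4], 2) scans window sums 3, 5, 7 and
# returns (7, 2, 4), i.e. the best 2-second span is [2, 4).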
def edit_video(template, df_all):
    # Each template pairs a music track with hard-coded cut points: max_subarray
    # picks the best k-second window of highlight scores, then black ColorClip
    # lead-ins and short "flash" subclips are sequenced to fit the track.
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    if template == 'Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 7)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/coming-in-hot.mp3'
        blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.6)
        flash1 = video.subclip(t_start=0, t_end=1.2)
        blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash2 = video.subclip(t_start=1.3, t_end=1.4)
        blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash3 = video.subclip(t_start=1.5, t_end=3.3)
        blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash4 = video.subclip(t_start=3.4, t_end=3.5)
        blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash5 = video.subclip(t_start=3.6, t_end=4.6)
        blank6 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash6 = video.subclip(t_start=4.7, t_end=4.8)
        blank7 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        highlight = video.subclip(t_start=4.9, t_end=6.384)
        output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, flash5, blank6, flash6, blank7, highlight])
    elif template == 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 7)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/thinking-out-loud.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.6)
        highlight = video.subclip(t_start=0, t_end=6.852)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Sheesh by Surfaces (upbeat, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 8)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/sheesh.mp3'
        blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=3.5)
        flash1 = video.subclip(t_start=0, t_end=0.1)
        blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash2 = video.subclip(t_start=0.2, t_end=0.3)
        blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash3 = video.subclip(t_start=0.4, t_end=0.5)
        blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
        flash4 = video.subclip(t_start=0.6, t_end=0.7)
        blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.9)
        highlight = video.subclip(t_start=1.6, t_end=7.18408163265)
        output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, highlight])
    elif template == 'Moon by Kid Francescoli (tranquil, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 9)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/and-it-went-like.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.9)
        highlight = video.subclip(t_start=0, t_end=8.132)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Ready Set by Joey Valence & Brae (old school, 10 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 11)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/ready-set.mp3'
        highlight = video.subclip(t_start=0, t_end=10.512)
        output = highlight
    elif template == 'Lovewave by The 1-800 (tranquil, 13 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 12)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/lovewave.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2.1)
        highlight = video.subclip(t_start=0, t_end=11.58)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 16)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/and-it-sounds-like.mp3'
        blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2)
        highlight = video.subclip(t_start=0, t_end=15.928)
        output = concatenate_videoclips([blank, highlight])
    elif template == 'Comfort Chain by Instupendo (lofi, 18 seconds)':
        res, left, right = max_subarray(df_all['y'].tolist(), 19)
        video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
        fps = video.fps
        x_dim = st.session_state.x_dim
        y_dim = st.session_state.y_dim
        music_path = 'music/comfort-chain.mp3'
        highlight = video.subclip(t_start=0, t_end=18.432)
        output = highlight
    # st.write(res, left, right)
    song = AudioFileClip(music_path)
    output = output.set_audio(song)
    output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    st.video('output.mp4')
    # return output
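# Adding a template follows the same pattern (hypothetical track name and timings):
# elif template == 'New Song by Some Artist (mood, N seconds)':
#     res, left, right = max_subarray(df_all['y'].tolist(), N)
#     video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
#     music_path = 'music/new-song.mp3'
#     output = video.subclip(t_start=0, t_end=N)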
def crop_video(df_all, left, right):
    # Cut the user-brushed [left, right] interval out of the source video and lay
    # a music loop under it. (df_all is currently unused here.)
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    video = VideoFileClip(video_path)
    fps = video.fps
    music_path = 'music/loop.mp3'
    song = AudioFileClip(music_path)
    video = video.set_audio(song)
    output = video.subclip(t_start=left, t_end=right)
    output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    st.video('output.mp4')
    # return output
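# Example (sketch): crop_video(st.session_state.df_all, 12, 25) exports seconds
# 12-25 of the source video with the matching slice of 'music/loop.mp3' underneath.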
st.set_page_config(page_title='Videogenic', page_icon='✨', layout='wide', initial_sidebar_state='collapsed')
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
* {font-family: Avenir; cursor: pointer;}
.css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
a:link {text-decoration: none;}
a:hover {text-decoration: none;}
.st-ba {font-family: Avenir;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# clustrmaps = """
# <a href="https://clustrmaps.com/site/1bham" target="_blank" title="Visit tracker"><img src="//www.clustrmaps.com/map_v2.png?d=NhNk5g9hy6Y06nqo7RirhHvZSr89uSS8rPrt471wAXw&cl=ffffff" width="0" height="0"></a>
# """
# st.markdown(clustrmaps, unsafe_allow_html=True)
# ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)
st.title('Videogenic ✨')
if 'progress' not in st.session_state:
    st.session_state.progress = 1
# mode = 'play'
# mode = 'brush'
# mode = 'select'
if st.session_state.progress == 1:
    with st.spinner('Loading model...'):
        load_model()
    domain = st.selectbox('Select video', ('Skydiving', 'Surfing', 'Skateboarding'))  # Entire journey, montage, vlog
    st.session_state.domain = domain
    if st.button('Process video'):
        with st.spinner('Processing video...'):
            video_name = f'videos/{st.session_state.domain.lower()}.mp4'
            with open(video_name, 'rb') as video_file:
                video_bytes = video_file.read()
            st.session_state.video = video_bytes
            # st.video(st.session_state.video)
            # Extract frames on the first run and cache them to disk; later runs
            # reload the cached frames and only re-read the video metadata.
            frames_path = f'files/{st.session_state.domain.lower()}.npy'
            if not os.path.exists(frames_path):
                video_frames, fps, x_dim, y_dim = video_to_frames(video_name)
                np.save(frames_path, video_frames)
            fps, x_dim, y_dim = video_to_info(video_name)
            video_frames = np.load(frames_path, allow_pickle=True)
            st.session_state.video_frames = video_frames
            st.session_state.fps = fps
            st.session_state.x_dim = x_dim
            st.session_state.y_dim = y_dim
            print('Extracted frames')
            # Same disk-caching pattern for the CLIP frame features.
            features_path = f'files/{st.session_state.domain.lower()}_features.npy'
            if not os.path.exists(features_path):
                encoded_frames = encode_frames(video_frames)
                np.save(features_path, encoded_frames.cpu().numpy())
            encoded_frames = np.load(features_path, allow_pickle=True)
            st.session_state.video_features = encoded_frames
            print('Encoded frames')
            st.session_state.progress = 2
# with open('activities.txt') as f:
#     activities_list = [line.rstrip('\n') for line in f]
# keywords = classify_activity(st.session_state.video_features, activities_list)
# st.write(keywords)
if st.session_state.progress == 2:
    mode = st.radio('Select mode', ('Automatic', 'User selection'))
    st.session_state.mode = mode
    # keywords = list(st.text_input('Enter topic').split(','))
    # if st.button('Compute scores') and keywords is not None:
    with st.spinner('Computing highlight scores...'):
        keyword = st.session_state.domain.lower()
        df_list = []
        # for keyword in keywords:
        img_set = get_photos(keyword)
        similarities = compute_scores(img_set, st.session_state.video_features, keyword)
        df = make_df(similarities, keyword)
        df_list.append(df)
        df_all = pd.concat(df_list, ignore_index=True, sort=False)
        st.session_state.df_all = df_all
    # st.write(df_all)
    # highlight_length = 7.033
    # st.write(st.session_state.fps)
    with st.spinner('Visualizing highlight scores...'):
        selection = altair_component(draw_chart(df_all, st.session_state.mode))
        print(selection)
        st.session_state.selection = selection
    # chart = draw_chart(df_all, mode)
    # st.altair_chart(chart, use_container_width=False)
    # st.session_state.progress = 3
    # if st.session_state.progress == 3:
    if st.session_state.mode == 'Automatic':
        # Full template list (currently trimmed to three in the selectbox below):
        # 'Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)',
        # 'Moon by Kid Francescoli (tranquil, 10 seconds)', 'Ready Set by Joey Valence & Brae (old school, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)',
        # 'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)', 'Comfort Chain by Instupendo (lofi, 18 seconds)'
        template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)'])
        if st.button('Generate video'):
            with st.spinner('Generating highlight video...'):
                edit_video(template, st.session_state.df_all)
            st.balloons()
    elif st.session_state.mode == 'User selection':
        if st.button('Generate video'):
            if 'x' not in st.session_state.selection:
                st.warning('Drag across the chart above to select an interval first.')
                st.stop()
            left = st.session_state.selection['x'][0]
            right = st.session_state.selection['x'][1]
            with st.spinner('Generating highlight video...'):
                crop_video(st.session_state.df_all, left, right)
            st.balloons()