videogenic / videogenic.py
chuanenlin's picture
Update videogenic.py
83f868a
import streamlit as st
# from pytube import YouTube
# from pytube import extract
import cv2
from PIL import Image
import clip as openai_clip
import torch
import math
import numpy as np
import tempfile
# from humanfriendly import format_timespan
import json
import sys
from random import randrange
import logging
# from pyunsplash import PyUnsplash
import requests
import io
from io import BytesIO
import base64
import altair as alt
from streamlit_vega_lite import altair_component
import pandas as pd
from datetime import timedelta
import math
from decord import VideoReader, cpu, gpu
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.editor import *
import glob
# @st.cache(show_spinner=False)
def load_model():
    """Load CLIP ViT-B/32 once per session and cache it in ``st.session_state``.

    Stores ``model``, ``preprocess`` and ``device`` ('cuda' when
    ``torch.cuda.is_available()``, otherwise 'cpu') in the session state.
    Streamlit re-executes the script on every interaction, so later calls
    return immediately when the session already holds all three values.
    """
    # Avoid re-initialising CLIP on every rerun; previously the model was loaded
    # each time and the guarded and unguarded assignments did the same thing.
    if all(key in st.session_state for key in ('model', 'preprocess', 'device')):
        return
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess = openai_clip.load('ViT-B/32', device=device)
    st.session_state.model = model
    st.session_state.preprocess = preprocess
    st.session_state.device = device
def fetch_video(url):
    """Return the 360p video-only MP4 stream of a YouTube *url* and its direct URL.

    Shows an error and stops the Streamlit script when the video is 5 minutes
    or longer, or when no 360p MP4 video-only stream exists.

    NOTE: ``YouTube`` comes from pytube, whose import is commented out at the
    top of this file. Re-enable that import before calling this function.

    Returns:
        (stream, stream.url)
    """
    yt = YouTube(url)
    # Check the length first so over-long videos are rejected before the stream query runs.
    if yt.length >= 300:
        st.error('Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.')
        st.stop()
    streams = yt.streams.filter(adaptive=True, subtype='mp4', resolution='360p', only_video=True)
    # first() returns None on an empty query; indexing [0] raised IndexError instead.
    video = streams.first()
    if video is None:
        st.error('No 360p MP4 video stream is available for this YouTube video.')
        st.stop()
    return video, video.url
# @st.cache()
# def extract_frames(video):
# frames = []
# capture = cv2.VideoCapture(video)
# fps = capture.get(cv2.CAP_PROP_FPS)
# current_frame = 0
# while capture.isOpened():
# ret, frame = capture.read()
# if ret == True:
# frames.append(Image.fromarray(frame[:, :, ::-1]))
# else:
# break
# current_frame += fps
# capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
# # print(f'Frames extracted: {len(frames)}')
# return frames, fps
# @st.cache()
def video_to_frames(video):
    """Decode roughly one frame per second of *video* into PIL images.

    Every ``round(fps)``-th frame is kept. When the average fps rounds to 0,
    every frame is kept instead.

    Returns:
        (frames, fps, x_dim, y_dim): the sampled PIL images, the average frame
        rate reported by the reader, and the frame width / height in pixels.

    Raises:
        ValueError: if the video contains no frames.
    """
    vr = VideoReader(video)
    frame_count = len(vr)
    if frame_count == 0:
        # Without this check x_dim/y_dim were never assigned (UnboundLocalError).
        raise ValueError(f'No frames could be decoded from {video!r}')
    fps = vr.get_avg_fps()
    # round(fps) is 0 for fps < 0.5, and range() rejects a zero step.
    step = max(1, round(fps))
    frames = []
    for index in range(0, frame_count, step):
        pixels = vr[index].asnumpy()
        frames.append(Image.fromarray(pixels))
    # Array rows x columns of the last decoded frame are height x width.
    y_dim, x_dim = pixels.shape[:2]
    return frames, fps, x_dim, y_dim
def video_to_info(video):
    """Return ``(fps, x_dim, y_dim)`` for *video* without decoding every frame.

    Only the first frame is decoded. ``x_dim``/``y_dim`` are its width and
    height in pixels (array columns / rows).

    Raises:
        ValueError: if the video contains no frames.
    """
    vr = VideoReader(video)
    if len(vr) == 0:
        raise ValueError(f'No frames could be decoded from {video!r}')
    fps = vr.get_avg_fps()
    y_dim, x_dim = vr[0].asnumpy().shape[:2]
    return fps, x_dim, y_dim
# @st.cache()
def encode_frames(video_frames, batch_size=256):
    """Embed each frame with the session's CLIP image encoder.

    Args:
        video_frames: sequence of frames accepted by ``st.session_state.preprocess``
            (presumably the PIL images from video_to_frames — confirm after the
            np.save/np.load round trip in the script).
        batch_size: number of frames pushed through the model at once.

    Returns:
        A tensor with one L2-normalised embedding row per frame, on
        ``st.session_state.device``. For empty input, an empty float16
        ``(0, 512)`` tensor is returned.
    """
    batches = []
    for start in range(0, len(video_frames), batch_size):
        chunk = video_frames[start:start + batch_size]
        pixels = torch.stack([st.session_state.preprocess(frame) for frame in chunk]).to(st.session_state.device)
        with torch.no_grad():
            features = st.session_state.model.encode_image(pixels)
        features /= features.norm(dim=-1, keepdim=True)
        batches.append(features)
    if not batches:
        return torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    # One concatenation at the end instead of re-copying a growing tensor per batch.
    return torch.cat(batches)
def classify_activity(video_features, activities_list, top_k=5):
    """Rank candidate activity descriptions against the video's CLIP features.

    Args:
        video_features: frame embeddings, one row per frame, as a NumPy array
            or tensor (presumably L2-normalised output of encode_frames).
        activities_list: candidate activity descriptions.
        top_k: maximum number of activities returned. It is capped at
            ``len(activities_list)`` so short lists no longer crash ``topk``.

    Returns:
        List of activity strings, most likely first. Empty if
        *activities_list* is empty.
    """
    if not activities_list:
        return []
    text = torch.cat([openai_clip.tokenize(str(activity)) for activity in activities_list]).to(st.session_state.device)
    with torch.no_grad():
        text_features = st.session_state.model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    logit_scale = st.session_state.model.logit_scale.exp()
    # from_numpy left the features on the CPU, which breaks the matmul when CLIP
    # runs on CUDA. Match the text features' device and dtype instead.
    video_features = torch.as_tensor(video_features).to(device=text_features.device, dtype=text_features.dtype)
    similarities = (logit_scale * video_features @ text_features.t()).softmax(dim=-1)
    k = min(top_k, len(activities_list))
    # NOTE(review): only row 0 (the first frame) is ranked, as before — confirm intended.
    _, word_idxs = similarities[0].topk(k)
    return [activities_list[idx] for idx in word_idxs.tolist()]
def encode_photos(photos, batch_size=256):
    """Embed image files with the session's CLIP image encoder.

    Args:
        photos: sequence of paths (or file objects) accepted by ``Image.open``.
        batch_size: number of images pushed through the model at once.

    Returns:
        A tensor with one L2-normalised embedding row per photo, on
        ``st.session_state.device``. For empty input, an empty float16
        ``(0, 512)`` tensor is returned.
    """
    batches = []
    for start in range(0, len(photos), batch_size):
        tensors = []
        for path in photos[start:start + batch_size]:
            # Close each file as soon as it has been preprocessed; the
            # previous Image.open calls leaked their file handles.
            with Image.open(path) as img:
                tensors.append(st.session_state.preprocess(img))
        pixels = torch.stack(tensors).to(st.session_state.device)
        with torch.no_grad():
            features = st.session_state.model.encode_image(pixels)
        features /= features.norm(dim=-1, keepdim=True)
        batches.append(features)
    if not batches:
        return torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
    return torch.cat(batches)
def img_to_bytes(img):
    """Encode *img* as JPEG and return the raw encoded bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format='JPEG')
        return buffer.getvalue()
def normalize(vector):
    """Min-max scale *vector* into [0, 1], preserving its shape.

    A constant input has no spread to scale by, so it maps to zeros instead
    of the NaNs produced by dividing by zero. An empty input is returned
    unchanged (as a float array).
    """
    values = np.asarray(vector, dtype=float)
    if values.size == 0:
        return values
    low = values.min()
    span = values.max() - low
    if span == 0:
        return np.zeros_like(values)
    return (values - low) / span
def format_img(img):
    """Return a thumbnail of *img*, at most 150x150, as a base64 PNG ``data:`` URI.

    The thumbnail is made from a copy because ``Image.thumbnail`` resizes in
    place. make_df passes the frames stored in ``st.session_state.video_frames``,
    and display_results also reads those frames.
    """
    thumb = img.copy()
    thumb.thumbnail((150, 150), Image.Resampling.LANCZOS)
    with io.BytesIO() as output:
        thumb.save(output, format='PNG')
        payload = base64.b64encode(output.getvalue()).decode()
    return f'data:image/png;base64,{payload}'
def get_photos(keyword, limit=1):
    """Load reference photos from ``photos/<keyword>/*.jpeg``.

    Args:
        keyword: domain name, lower-cased to form the directory name (the
            script passes ``st.session_state.domain.lower()``).
        limit: maximum number of photos to load (the app uses one).

    Returns:
        List of fully decoded PIL images. Empty when nothing matches.

    NOTE(security): an earlier, commented-out variant fetched photos from
    Unsplash with hard-coded API keys. It has been removed; those keys remain
    in version history and should be rotated.
    """
    # Sort so the chosen photo does not depend on filesystem listing order.
    filenames = sorted(glob.glob(f'photos/{keyword.lower()}/*.jpeg'))[:limit]
    photo_collection = []
    for filename in filenames:
        # copy() fully decodes the image so its file handle can be closed now.
        with Image.open(filename) as photo:
            photo_collection.append(photo.copy())
    return photo_collection
def display_results(best_photo_idx):
    """Show the frames at *best_photo_idx* with ``st.image`` and return them.

    Args:
        best_photo_idx: iterable of indices into ``st.session_state.video_frames``.

    Returns:
        The displayed frames, in display order.
    """
    st.markdown(f'**Top {len(best_photo_idx)} highlights**')
    result_arr = []
    for frame_id in best_photo_idx:
        result = st.session_state.video_frames[frame_id]
        st.image(result)
        # Collect the frame; the list was previously never filled, so [] was returned.
        result_arr.append(result)
    return result_arr
def make_df(similarities, keyword=None):
    """Build the DataFrame plotted by draw_chart from per-frame similarity scores.

    Args:
        similarities: one score per sampled frame. An ``(N, 1)`` column, as
            returned by compute_scores, is flattened.
        keyword: label stored in the ``keyword`` column. Defaults to
            ``st.session_state.domain.lower()``, the value the script used to
            provide through an implicit module-level ``keyword`` global.

    Returns:
        DataFrame with columns ``keyword``, ``x`` (frame index), ``y`` (score**8
        min-max normalised to [0, 1]) and ``image`` (base64 thumbnail of each
        frame in ``st.session_state.video_frames``).
    """
    if keyword is None:
        keyword = st.session_state.domain.lower()
    scores = np.ravel(similarities)
    df = pd.DataFrame()
    df['keyword'] = [keyword] * len(scores)
    df['x'] = range(len(scores))
    # The 8th power exaggerates peaks relative to the baseline before scaling.
    df['y'] = normalize(np.power(scores, 8))
    df['image'] = [format_img(frame) for frame in st.session_state.video_frames]
    return df
# @st.cache()
def compute_scores(search_query, video_features, text_query, display_results_count=10):
    """Score every frame against the mean CLIP embedding of the reference photos.

    Args:
        search_query: non-empty iterable of images accepted by
            ``st.session_state.preprocess`` (e.g. from get_photos).
        video_features: frame embeddings, one row per frame (a NumPy array after
            the script's np.save/np.load round trip, or a tensor).
        text_query: unused; kept for call compatibility.
        display_results_count: unused; kept for call compatibility.

    Returns:
        NumPy array of shape ``(num_frames, 1)`` holding 100 x the dot product of
        each frame embedding with the mean L2-normalised photo embedding.

    Raises:
        ValueError: if *search_query* is empty (the mean would be NaN).
    """
    photo_features = []
    for photo in search_query:
        pixels = st.session_state.preprocess(photo).unsqueeze(0).to(st.session_state.device)
        with torch.no_grad():
            features = st.session_state.model.encode_image(pixels)
        features /= features.norm(dim=-1, keepdim=True)
        photo_features.append(features)
    if not photo_features:
        raise ValueError('compute_scores needs at least one reference photo')
    # True mean of the photo embeddings. The old `sum += sum + f` doubled the
    # running total at each step, which over-weighted the earlier photos.
    avg_photo = torch.cat(photo_features).mean(dim=0, keepdim=True)
    # Work on the CPU in float32 so CUDA/float16 embeddings and NumPy input mix safely.
    frames = torch.as_tensor(video_features).float().cpu()
    similarities = 100.0 * frames @ avg_photo.float().cpu().T
    return similarities.numpy()
def avenir():
    """Altair theme: render chart titles and axis text in the Avenir typeface."""
    typeface = 'Avenir'
    axis_fonts = dict(labelFont=typeface, titleFont=typeface)
    return {'config': {'title': {'font': typeface}, 'axis': axis_fonts}}
# Register the avenir() theme under the name 'avenir' and make it Altair's
# active theme for the charts built by draw_chart.
alt.themes.register('avenir', avenir)
alt.themes.enable('avenir')
# TODO: Make playhead scores and average according to keyword
# TODO: Maximum interval selection
# TODO: Interactive legend https://altair-viz.github.io/gallery/interactive_legend.html
# TODO: Multi-line highlight https://altair-viz.github.io/gallery/multiline_highlight.html
def _hover_chart(df):
    """Score line with a hover rule, a score label and the hovered frame's thumbnail."""
    nearest = alt.selection(type='single', nearest=True, on='mouseover', empty='none')
    line = alt.Chart(df).mark_line().encode(
        x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
        y=alt.Y('y', axis=alt.Axis(labels=False, tickSize=0, title='')),
        color=alt.value('#00C7BE'),
    )
    # Invisible points spanning the x range; they capture the mouse position.
    selectors = alt.Chart(df).mark_point().encode(
        x='x:Q',
        opacity=alt.value(0),
    ).add_selection(nearest)
    rules = alt.Chart(df).mark_rule(color='black').encode(
        x='x:Q',
    ).transform_filter(nearest)
    points = line.mark_point().encode(
        opacity=alt.condition(nearest, alt.value(1), alt.value(0))
    )
    text = line.mark_text(align='center', yOffset=-110, fontSize=16).encode(
        text=alt.condition(nearest, 'y:N', alt.value(' ')),
        color=alt.value('#000000'),
    ).transform_calculate(y='format(datum.y, ".2f")')
    image = line.mark_image(align='center', width=150, height=150, yOffset=-60).encode(
        url=alt.condition(nearest, 'image', alt.value(' '))
    )
    return alt.layer(line, selectors, points, rules, text, image)


def _brush_chart(df):
    """Score line with an x-interval brush (read back via altair_component)."""
    brush = alt.selection(type='interval', encodings=['x'])
    # The previous version also built a mean-score label and a dashed average
    # rule here, but never layered them into the chart; that dead code is gone.
    return alt.Chart(df).mark_line().encode(
        x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
        y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
        color=alt.value('#00C7BE'),
    ).add_selection(brush)


@st.cache
def draw_chart(df, mode):
    """Build the highlight-score chart for *df* (as produced by make_df).

    Args:
        df: DataFrame with ``x``, ``y`` and ``image`` columns.
        mode: 'Automatic' for the hover chart; 'User selection' (or the legacy
            'brush') for the interval-brush chart. The argument is used, not
            ``st.session_state.mode``, so the ``st.cache`` key matches the
            chart that is drawn.

    Raises:
        ValueError: for any other *mode* (previously an UnboundLocalError).
    """
    if mode == 'Automatic':
        chart = _hover_chart(df)
    elif mode in ('brush', 'User selection'):
        chart = _brush_chart(df)
    else:
        raise ValueError(f'Unknown chart mode: {mode!r}')
    return chart.properties(width=1250, height=500).configure_axis(grid=False, domain=False).configure_view(strokeOpacity=0)
def max_subarray(arr, k):
n = len(arr)
if (n < k):
st.write('Video too short')
res = 0
left = 0
right = k
for i in range(k):
res += arr[i]
curr_sum = res
for i in range(k, n):
curr_sum += arr[i] - arr[i - k]
if curr_sum > res:
res = curr_sum
left = i - k
right = i
return res, left, right
# Music templates for edit_video. Each maps the UI label to
# (window, music_path, timeline):
#   window      - window length passed to max_subarray, in sampled frames
#                 (video_to_frames samples about one frame per second);
#   music_path  - soundtrack laid over the edit;
#   timeline    - ('blank', duration) for a black ColorClip, or
#                 ('clip', t_start, t_end) for a cut from the selected window.
_EDIT_TEMPLATES = {
    'Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)': (7, 'music/coming-in-hot.mp3', [
        ('blank', 0.6), ('clip', 0, 1.2),
        ('blank', 0.1), ('clip', 1.3, 1.4),
        ('blank', 0.1), ('clip', 1.5, 3.3),
        ('blank', 0.1), ('clip', 3.4, 3.5),
        ('blank', 0.1), ('clip', 3.6, 4.6),
        ('blank', 0.1), ('clip', 4.7, 4.8),
        ('blank', 0.1), ('clip', 4.9, 6.384),
    ]),
    'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)': (7, 'music/thinking-out-loud.mp3', [
        ('blank', 1.6), ('clip', 0, 6.852),
    ]),
    'Sheesh by Surfaces (upbeat, 10 seconds)': (8, 'music/sheesh.mp3', [
        ('blank', 3.5), ('clip', 0, 0.1),
        ('blank', 0.1), ('clip', 0.2, 0.3),
        ('blank', 0.1), ('clip', 0.4, 0.5),
        ('blank', 0.1), ('clip', 0.6, 0.7),
        ('blank', 0.9), ('clip', 1.6, 7.18408163265),
    ]),
    'Moon by Kid Francescoli (tranquil, 10 seconds)': (9, 'music/and-it-went-like.mp3', [
        ('blank', 1.9), ('clip', 0, 8.132),
    ]),
    'Ready Set by Joey Valence & Brae (old school, 10 seconds)': (11, 'music/ready-set.mp3', [
        ('clip', 0, 10.512),
    ]),
    'Lovewave by The 1-800 (tranquil, 13 seconds)': (12, 'music/lovewave.mp3', [
        ('blank', 2.1), ('clip', 0, 11.58),
    ]),
    'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)': (16, 'music/and-it-sounds-like.mp3', [
        ('blank', 2), ('clip', 0, 15.928),
    ]),
    'Comfort Chain by Instupendo (lofi, 18 seconds)': (19, 'music/comfort-chain.mp3', [
        ('clip', 0, 18.432000000000002),
    ]),
}


def edit_video(template, df_all):
    """Cut the highest-scoring window of the selected video to a music template.

    Finds the best window of ``df_all['y']`` with max_subarray, builds the
    template's blank/clip timeline from it, adds the template's song, writes
    ``output.mp4`` and shows it with ``st.video``.

    Args:
        template: one of the labels in ``_EDIT_TEMPLATES``.
        df_all: DataFrame with a per-frame score column ``y`` (see make_df).

    Raises:
        ValueError: for an unknown *template* (previously an UnboundLocalError).
    """
    try:
        window, music_path, timeline = _EDIT_TEMPLATES[template]
    except KeyError:
        raise ValueError(f'Unknown template: {template!r}') from None
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    _, left, right = max_subarray(df_all['y'].tolist(), window)
    source = VideoFileClip(video_path)
    song = None
    try:
        video = source.subclip(t_start=left, t_end=right)
        fps = video.fps
        size = (st.session_state.x_dim, st.session_state.y_dim)
        pieces = []
        for kind, *times in timeline:
            if kind == 'blank':
                pieces.append(ColorClip(size, (0, 0, 0), duration=times[0]))
            else:
                t_start, t_end = times
                pieces.append(video.subclip(t_start=t_start, t_end=t_end))
        # Single-cut templates use the subclip directly, as before.
        output = pieces[0] if len(pieces) == 1 else concatenate_videoclips(pieces)
        song = AudioFileClip(music_path)
        output = output.set_audio(song)
        output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    finally:
        # Release the readers opened by moviepy.
        if song is not None:
            song.close()
        source.close()
    st.video('output.mp4')
def crop_video(df_all, left, right):
    """Cut ``[left, right]`` seconds of the selected video, add music and show it.

    Writes ``output.mp4`` with ``music/loop.mp3`` as the soundtrack and shows it
    with ``st.video``.

    Args:
        df_all: unused; kept for call compatibility.
        left, right: start / end in seconds (the script passes the brushed x
            interval, where x is the sampled-frame index, about one per second).
    """
    video_path = f'videos/{st.session_state.domain.lower()}.mp4'
    video = VideoFileClip(video_path)
    song = AudioFileClip('music/loop.mp3')
    try:
        fps = video.fps
        # Keep the cut inside the clip in case the brush extends past the ends.
        start = max(0, left)
        end = min(right, video.duration)
        # Cut first, then attach the music, so the song starts from its beginning
        # (as in edit_video). Attaching it before subclipping offset the song by
        # `left` seconds, which can run past the end of the track.
        output = video.subclip(t_start=start, t_end=end).set_audio(song)
        output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
    finally:
        song.close()
        video.close()
    st.video('output.mp4')
# Browser-tab metadata and layout; the sidebar starts collapsed.
st.set_page_config(
    page_title='Videogenic',
    page_icon='✨',
    layout='wide',
    initial_sidebar_state='collapsed',
)

# Page-wide CSS: hide Streamlit's main menu and footer, use Avenir everywhere,
# restyle the auto-generated .css-gma2qf class (presumably the title), and
# drop link underlines.
hide_streamlit_style = '\n'.join([
    '',
    '<style>',
    '#MainMenu {visibility: hidden;}',
    'footer {visibility: hidden;}',
    '* {font-family: Avenir; cursor: pointer;}',
    '.css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}',
    'a:link {text-decoration: none;}',
    'a:hover {text-decoration: none;}',
    '.st-ba {font-family: Avenir;}',
    '</style>',
    '',
])
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

st.title('Videogenic ✨')
# `progress` records which stage of the app this session has reached. It starts
# at 1 and is set to 2 further down, once the video has been processed.
if 'progress' not in st.session_state:
st.session_state.progress = 1
# mode = 'play'
# mode = 'brush'
# mode = 'select'
# Stage 1: load CLIP, pick one of the bundled videos and embed its frames.
if st.session_state.progress == 1:
with st.spinner('Loading model...'):
load_model()
domain = st.selectbox('Select video',('Skydiving', 'Surfing', 'Skateboarding')) # Entire journey, montage, vlog
# NOTE(review): the guarded assignment looks redundant; the same assignment is repeated right after it.
if 'domain' not in st.session_state:
st.session_state.domain = domain
st.session_state.domain = domain
if st.button('Process video'):
with st.spinner('Processing video...'):
video_name = f'videos/{st.session_state.domain.lower()}.mp4'
# NOTE(review): this file handle is never closed; `with open(...)` would release it.
video_file = open(video_name, 'rb')
video_bytes = video_file.read()
if 'video' not in st.session_state:
st.session_state.video = video_bytes
st.session_state.video = video_bytes
# st.video(st.session_state.video)
# Sample about one frame per second, cache the frames in files/<domain>.npy,
# then read the cache back. The "first run" remarks suggest the decode and
# np.save lines are meant to be disabled once the cache exists, leaving
# video_to_info to supply fps and frame size. As written, every run decodes again.
# NOTE(review): np.save on a list of PIL images may store a numeric array
# rather than image objects. Confirm that what np.load returns is accepted by
# encode_frames / format_img.
video_frames, fps, x_dim, y_dim = video_to_frames(video_name) # first run; video_to_info
np.save(f'files/{st.session_state.domain.lower()}.npy', video_frames)
fps, x_dim, y_dim = video_to_info(video_name)
video_frames = np.load(f'files/{st.session_state.domain.lower()}.npy', allow_pickle=True)
# Keep frames and geometry for later stages: edit_video reads x_dim/y_dim,
# and make_df / display_results read video_frames.
if 'video_frames' not in st.session_state:
st.session_state.video_frames = video_frames
st.session_state.fps = fps
st.session_state.x_dim = x_dim
st.session_state.y_dim = y_dim
st.session_state.video_frames = video_frames
st.session_state.fps = fps
st.session_state.x_dim = x_dim
st.session_state.y_dim = y_dim
print('Extracted frames')
# Embed every sampled frame with CLIP, cache the features in
# files/<domain>_features.npy, and again read the cache back.
# NOTE(review): if encode_frames returns a CUDA tensor, np.save may be unable
# to convert it to NumPy. Confirm on a GPU machine.
encoded_frames = encode_frames(video_frames) # first run
np.save(f'files/{st.session_state.domain.lower()}_features.npy', encoded_frames)
encoded_frames = np.load(f'files/{st.session_state.domain.lower()}_features.npy', allow_pickle=True)
if 'video_features' not in st.session_state:
# st.session_state.video_features = encoded_frames
st.session_state.video_features = encoded_frames
st.session_state.video_features = encoded_frames
print('Encoded frames')
# Advance to stage 2 (scoring, visualisation and video generation).
st.session_state.progress = 2
# Disabled: zero-shot activity classification against activities.txt.
# with open('activities.txt') as f:
# activities_list = [line.rstrip('\n') for line in f]
# keywords = classify_activity(st.session_state.video_features, activities_list)
# st.write(keywords)
# Stage 2: score each sampled frame against the domain's reference photo,
# chart the scores, and turn them into a highlight video.
if st.session_state.progress == 2:
mode = st.radio('Select mode', ('Automatic', 'User selection'))
if 'mode' not in st.session_state:
st.session_state.mode = mode
st.session_state.mode = mode
# keywords = list(st.text_input('Enter topic').split(','))
# if st.button('Compute scores') and keywords is not None:
with st.spinner('Computing highlight scores...'):
# make_df() reads this module-level `keyword` as a global, so the name matters.
keyword = st.session_state.domain.lower()
df_list = []
# for keyword in keywords:
img_set = get_photos(keyword)
similarities = compute_scores(img_set, st.session_state.video_features, keyword)
# st.write(similarities)
df = make_df(similarities)
df_list.append(df)
df_all = pd.concat(df_list, ignore_index=True, sort=False)
if 'df_all' not in st.session_state:
st.session_state.df_all = df_all
st.session_state.df_all = df_all
# st.write(df_all)
# highlight_length = 7.033
# st.write(st.session_state.fps)
# Render through streamlit_vega_lite's altair_component so the chart's
# selection (the brushed 'x' interval in 'User selection' mode) is returned.
with st.spinner('Visualizing highlight scores...'):
selection = altair_component(draw_chart(df_all, st.session_state.mode))
# NOTE(review): debug print on every rerun; consider removing or using logging.
print(selection)
if 'selection' not in st.session_state:
st.session_state.selection = selection
st.session_state.selection = selection
# if '_vgsid_' in selection:
# # the ids start at 1
# st.write(df.iloc[[selection['_vgsid_'][0] - 1]])
# else:
# st.info('Hover over the chart above to see details about the Penguin here.')
# if 'x' in selection:
# # the ids start at 1
# st.write(selection['x'])
# chart = draw_chart(df_all, mode)
# st.altair_chart(chart, use_container_width=False)
# st.session_state.progress = 3
# if st.session_state.progress == 3:
# Automatic mode: edit_video cuts the best-scoring window onto a music template.
if st.session_state.mode == 'Automatic':
# template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)',
# 'Moon by Kid Francescoli (tranquil, 10 seconds)', 'Ready Set by Joey Valence & Brae (old school, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)',
# 'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)', 'Comfort Chain by Instupendo (lofi, 18 seconds)'])
template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)'])
if st.button('Generate video'):
with st.spinner('Generating highlight video...'):
edit_video(template, st.session_state.df_all)
st.balloons()
# User-selection mode: crop_video cuts the brushed interval and adds the loop track.
elif st.session_state.mode == 'User selection':
if st.button('Generate video'):
# NOTE(review): assumes an interval has been brushed. If `selection` has no
# 'x' key this raises KeyError; confirm and guard if needed.
left = st.session_state.selection['x'][0]
right = st.session_state.selection['x'][1]
with st.spinner('Generating highlight video...'):
crop_video(st.session_state.df_all, left, right)
st.balloons()