import base64
import glob
import os
import subprocess
import sys

import streamlit as st
import streamlit.components.v1 as components
def find_latest_file(base_path, extension):
    """Return the newest ``*.<extension>`` file directly inside *base_path*.

    Only the top level of *base_path* is searched (no recursion). "Newest" is
    decided by ``os.path.getctime``. Returns ``None`` when no file matches.
    """
    pattern = f'{base_path}/*.{extension}'
    return max(glob.glob(pattern), key=os.path.getctime, default=None)
def generate_html(
    text_input,
    length,
    resume_pth="output/vq/2023-07-19-04-17-17_12_VQVAE_20batchResetNRandom_8192_32/net_last.pth",
    resume_trans="output/t2m/2023-10-10-03-17-01_HML3D_44_crsAtt2lyr_mask0.5-1/net_last.pth",
):
    """Run ``generate.py`` for a text prompt and locate the files it produced.

    Args:
        text_input: Prompt passed to ``generate.py --text``.
        length: Motion length passed to ``generate.py --length`` (converted
            with ``str``).
        resume_pth: Checkpoint path passed as ``--resume-pth``.
        resume_trans: Checkpoint path passed as ``--resume-trans``.

    Returns:
        ``(html_file, npy_file)``: the newest ``.html`` and ``.npy`` files in
        ``output/`` (each ``None`` if no such file exists). Returns
        ``(None, None)`` if the script exits non-zero or cannot be launched;
        in that case the error is shown to the user via ``st.error``.
    """
    command = [
        # Use the interpreter running this app rather than whatever "python"
        # happens to be on PATH (it may be a different env, or absent).
        sys.executable, "generate.py",
        "--resume-pth", resume_pth,
        "--resume-trans", resume_trans,
        "--text", text_input,
        "--length", str(length),
    ]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None, None
    except OSError as e:
        # The script could not be launched at all (e.g. interpreter missing).
        st.error(f"Error: {e}")
        return None, None
    # NOTE(review): this assumes the newest .html/.npy in output/ are the ones
    # generate.py just wrote; nothing here ties the files to this run.
    return find_latest_file('output', 'html'), find_latest_file('output', 'npy')
def run_render_final(npy_file_path):
    """Run ``render_final.py`` on a generated ``.npy`` file and locate the video.

    Args:
        npy_file_path: Path passed as the only argument to ``render_final.py``.

    Returns:
        The newest ``.mp4`` in ``output/`` (``None`` if there is none).
        Returns ``None`` if the script exits non-zero or cannot be launched;
        in that case the error is shown to the user via ``st.error``.
    """
    # Same interpreter as the running app, not whatever "python" is on PATH.
    command = [sys.executable, "render_final.py", npy_file_path]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None
    except OSError as e:
        # The script could not be launched at all.
        st.error(f"Error: {e}")
        return None
    # NOTE(review): assumes render_final.py writes its video into output/ and
    # that the newest .mp4 there belongs to this run — confirm.
    return find_latest_file('output', 'mp4')
def gif_to_base64(gif_file_path):
    """Return the raw bytes of *gif_file_path* base64-encoded as a ``str``.

    The file is read in binary mode. The ASCII base64 output is decoded with
    UTF-8 so it can be embedded directly in HTML/data URIs.
    """
    with open(gif_file_path, "rb") as handle:
        return base64.b64encode(handle.read()).decode("utf-8")
# Seed the session-state keys backing the input widgets. The guards mean this
# only happens when a key is absent, so values set later are not overwritten
# on script reruns.
if "text_input" not in st.session_state:
    st.session_state["text_input"] = ""
if "length" not in st.session_state:
    st.session_state["length"] = 156
def select_prompt(prompt, prompt_length):
    """Copy an example prompt and its motion length into session state.

    The input widgets are keyed on ``text_input`` and ``length``, so they show
    these values when they are created after this call in the same run.
    """
    state = st.session_state
    state["text_input"] = prompt
    state["length"] = prompt_length
# App layout
# NOTE(review): the header markup was garbled in the source; only the text
# "MMM Model Demo" survived, so the <h1> wrapper is a reconstruction.
components.html("<h1>MMM Model Demo</h1>", height=100)

# Example prompts, each paired with the motion length loaded alongside it.
prompts = [
    ("A person walks forward then turns completely around and does a cartwheel", 196),
    ("A person bouncing around while throwing jabs and upper cuts.", 196),
    ("A person start to dance with legs", 176),
    ("A person steps forward and leans over; they grab a cup with their left hand and empty it before putting it down and stepping back to their original position.", 156),
    ("Walking forward and kicking foot.", 68),
    ("A man walks forward, stumbles to the right, and then regains his balance and keeps walking forwards.", 92),
]

col1, col2 = st.columns([6, 5])

with col1:
    # Reserve the slot for the input fields now but fill it only after the
    # prompt buttons below have run, so a click can update the widgets'
    # session-state keys before those widgets are created in this run.
    input_placeholder = st.empty()

with col2:
    st.write("Or choose a prompt:")
    for prompt, prompt_length in prompts:
        if st.button(prompt):
            select_prompt(prompt, prompt_length)

# Render the input fields inside the placeholder. Their values come solely from
# the "text_input"/"length" session-state keys; also passing value= makes
# Streamlit warn that the widget has both a default and a Session State value.
# min_value/step are ints so the length widget stays integer-typed.
with input_placeholder.container():
    st.text_area("Enter text here:", key="text_input", height=300)
    st.number_input("Length of the generated motion:", min_value=1, step=1, key="length")

# Place the buttons side by side
button_col1, button_col2 = st.columns(2)

with button_col1:
    if st.button("Generate HTML"):
        if st.session_state.text_input and st.session_state.length:
            html_file_path, npy_file_path = generate_html(
                st.session_state.text_input, st.session_state.length
            )
            if html_file_path and npy_file_path:
                st.session_state.html_file_path = html_file_path
                st.session_state.npy_file_path = npy_file_path
                # A new motion invalidates the video rendered for the old one;
                # otherwise the stale video would be shown next to the new HTML.
                if "vid_file_path" in st.session_state:
                    del st.session_state["vid_file_path"]
                with open(html_file_path, "r", encoding="utf-8") as file:
                    st.session_state.html_content = file.read()
            else:
                st.error("Error generating files. Please try again.")
        else:
            st.warning("Please enter a text prompt and a motion length first.")

with button_col2:
    if st.button("Render Skeleton"):
        if "npy_file_path" in st.session_state and st.session_state.npy_file_path:
            vid_file_path = run_render_final(st.session_state.npy_file_path)
            if vid_file_path:
                st.session_state.vid_file_path = vid_file_path
        else:
            st.error("No npy file found. Please generate HTML first.")

# Display the results side by side. The gate checks the keys that are actually
# stored: html_content and vid_file_path.
if "html_content" in st.session_state or "vid_file_path" in st.session_state:
    html_content = st.session_state.html_content if "html_content" in st.session_state else ""
    video_path = st.session_state.vid_file_path if "vid_file_path" in st.session_state else ""
    disp_col1, disp_col2 = st.columns([1, 1])
    with disp_col1:
        components.html(html_content, height=800, scrolling=True)
    with disp_col2:
        if video_path:
            # Close the file promptly; the original left the handle open.
            with open(video_path, "rb") as video_file:
                video_bytes = video_file.read()
            st.video(video_bytes, format="video/mp4", loop=True)