import base64
import glob
import os
import subprocess
import sys

import streamlit as st
import streamlit.components.v1 as components

# Directory scanned for the artifacts written by generate.py / render_final.py.
OUTPUT_DIR = "output"
VQ_CHECKPOINT = "output/vq/2023-07-19-04-17-17_12_VQVAE_20batchResetNRandom_8192_32/net_last.pth"
TRANS_CHECKPOINT = "output/t2m/2023-10-10-03-17-01_HML3D_44_crsAtt2lyr_mask0.5-1/net_last.pth"
DEFAULT_LENGTH = 156
# Run helper scripts with the interpreter running this app so they see the
# same virtualenv / installed packages; a bare "python" may resolve elsewhere.
PYTHON_EXECUTABLE = sys.executable or "python"


def find_latest_file(base_path, extension):
    """Return the most recently modified ``*.{extension}`` file in *base_path*.

    Only the top level of *base_path* is searched (the glob is not recursive).

    Returns:
        Path of the newest matching file, or ``None`` if there is none.
    """
    candidates = glob.glob(os.path.join(base_path, f"*.{extension}"))
    if not candidates:
        return None
    # Modification time, not ctime: on Windows getctime is the *creation* time,
    # so a file overwritten in place would never be seen as the newest.
    return max(candidates, key=os.path.getmtime)


def generate_html(text_input, length):
    """Run generate.py for a text prompt and return the newest output artifacts.

    Args:
        text_input: Prompt passed to generate.py as ``--text``.
        length: Motion length passed as ``--length`` (converted with ``str``).

    Returns:
        ``(html_path, npy_path)``: the newest ``.html`` and ``.npy`` files in
        ``OUTPUT_DIR`` after the run (either may be ``None`` if none exists),
        or ``(None, None)`` if generate.py exits with a non-zero status, in
        which case its stderr is shown with ``st.error``.
    """
    command = [
        PYTHON_EXECUTABLE,
        "generate.py",
        "--resume-pth", VQ_CHECKPOINT,
        "--resume-trans", TRANS_CHECKPOINT,
        "--text", text_input,
        "--length", str(length),
    ]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None, None
    # NOTE(review): outputs are located as "newest file in OUTPUT_DIR" rather
    # than by a name reported by generate.py, so concurrent sessions could pick
    # up each other's results.
    return find_latest_file(OUTPUT_DIR, "html"), find_latest_file(OUTPUT_DIR, "npy")


def run_render_final(npy_file_path):
    """Run render_final.py on an NPY file and return the newest ``.mp4``.

    Returns:
        Path of the newest ``.mp4`` in ``OUTPUT_DIR`` after the run (``None``
        if there is none), or ``None`` if render_final.py exits with a
        non-zero status, in which case its stderr is shown with ``st.error``.
    """
    command = [PYTHON_EXECUTABLE, "render_final.py", npy_file_path]
    try:
        subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Error: {e.stderr}")
        return None
    return find_latest_file(OUTPUT_DIR, "mp4")


def gif_to_base64(gif_file_path):
    """Return the bytes of *gif_file_path* as a base64-encoded string."""
    with open(gif_file_path, "rb") as gif_file:
        return base64.b64encode(gif_file.read()).decode("utf-8")


# Seed the session-state keys that back the input widgets below; the widgets
# read their initial values from these keys.
if "text_input" not in st.session_state:
    st.session_state.text_input = ""
if "length" not in st.session_state:
    st.session_state.length = DEFAULT_LENGTH


def select_prompt(prompt, prompt_length):
    """Load an example prompt and its motion length into session state.

    Streamlit does not allow a widget's session-state key to be modified after
    that widget has been created in the current run, so this must be called
    before the ``text_input`` / ``length`` widgets are rendered (hence the
    placeholder layout below).
    """
    st.session_state.text_input = prompt
    st.session_state.length = prompt_length


# App layout
# NOTE(review): this header is plain text; any HTML markup/styling it once had
# may have been lost from this copy of the file — confirm.
components.html("""

MMM Model Demo

""", height=100)

prompts = [
    ("A person walks forward then turns completely around and does a cartwheel", 196),
    ("A person bouncing around while throwing jabs and upper cuts.", 196),
    ("A person start to dance with legs", 176),
    ("A person steps forward and leans over; they grab a cup with their left hand and empty it before putting it down and stepping back to their original position.", 156),
    ("Walking forward and kicking foot.", 68),
    ("A man walks forward, stumbles to the right, and then regains his balance and keeps walking forwards.", 92),
]

col1, col2 = st.columns([6, 5])

# Reserve col1 for the input widgets, but fill it only after the prompt buttons
# in col2 have run, so a button click can update the widget keys first.
with col1:
    input_placeholder = st.empty()

with col2:
    st.write("Or choose a prompt:")
    for prompt, prompt_length in prompts:
        if st.button(prompt):
            select_prompt(prompt, prompt_length)

with input_placeholder.container():
    # No value= argument: the widgets take their values from session state via
    # key=, and passing both triggers Streamlit's "created with a default value
    # but also had its value set via the Session State API" warning.
    text_input = st.text_area("Enter text here:", key="text_input", height=300)
    # Integer min_value/step keep the widget integer-typed (so --length gets
    # e.g. "156", not "156.0") and reject non-positive lengths.
    length = st.number_input(
        "Length of the generated motion:", min_value=1, step=1, key="length"
    )

# Place the buttons side by side
button_col1, button_col2 = st.columns(2)
with button_col1:
    if st.button("Generate HTML"):
        if st.session_state.text_input.strip() and st.session_state.length:
            html_file_path, npy_file_path = generate_html(
                st.session_state.text_input, st.session_state.length
            )
            if html_file_path and npy_file_path:
                st.session_state.html_file_path = html_file_path
                st.session_state.npy_file_path = npy_file_path
                with open(html_file_path, "r", encoding="utf-8") as file:
                    st.session_state.html_content = file.read()
                # Any previously rendered video belongs to the old motion;
                # drop it so it is not shown next to the new result.
                if "vid_file_path" in st.session_state:
                    del st.session_state.vid_file_path
            else:
                st.error("Error generating files. Please try again.")
        else:
            st.warning("Please enter a text prompt and a motion length.")

with button_col2:
    if st.button("Render Skeleton"):
        if st.session_state.get("npy_file_path"):
            vid_file_path = run_render_final(st.session_state.npy_file_path)
            if vid_file_path:
                st.session_state.vid_file_path = vid_file_path
            else:
                st.error("No video was produced. Please try again.")
        else:
            st.error("No npy file found. Please generate HTML first.")

# Display the results side by side
if "html_content" in st.session_state or "vid_file_path" in st.session_state:
    html_content = st.session_state.get("html_content", "")
    video_path = st.session_state.get("vid_file_path", "")
    disp_col1, disp_col2 = st.columns([1, 1])
    with disp_col1:
        components.html(html_content, height=800, scrolling=True)
    with disp_col2:
        # The stored path can outlive the file (e.g. output cleanup), so check
        # it still exists before reading.
        if video_path and os.path.isfile(video_path):
            with open(video_path, "rb") as video_file:
                video_bytes = video_file.read()
            st.video(video_bytes, format="video/mp4", loop=True)