feature: added rendering
Browse files- app.py +78 -26
- generate.py +12 -5
- render_final.py +1 -2
- requirements.txt +67 -3
app.py
CHANGED
@@ -1,10 +1,19 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import streamlit.components.v1 as components
|
3 |
import subprocess
|
4 |
import os
|
5 |
import glob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
# Function to run the
|
8 |
def generate_html(text_input, length):
|
9 |
command = [
|
10 |
"python", "generate.py",
|
@@ -15,18 +24,30 @@ def generate_html(text_input, length):
|
|
15 |
]
|
16 |
try:
|
17 |
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
18 |
-
|
|
|
|
|
19 |
except subprocess.CalledProcessError as e:
|
20 |
-
|
21 |
-
return None
|
22 |
|
23 |
-
# Function to
|
24 |
-
def
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
27 |
return None
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# Initialize session state
|
32 |
if 'text_input' not in st.session_state:
|
@@ -34,16 +55,14 @@ if 'text_input' not in st.session_state:
|
|
34 |
if 'length' not in st.session_state:
|
35 |
st.session_state.length = 156
|
36 |
|
37 |
-
#
|
38 |
def select_prompt(prompt, prompt_length):
|
39 |
st.session_state.text_input = prompt
|
40 |
st.session_state.length = prompt_length
|
41 |
|
42 |
-
# app layout
|
43 |
-
components.html("""
|
44 |
-
<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>
|
45 |
-
""", height=100)
|
46 |
|
|
|
|
|
47 |
|
48 |
prompts = [
|
49 |
("A person walks forward then turns completely around and does a cartwheel", 196),
|
@@ -73,13 +92,46 @@ with input_placeholder.container():
|
|
73 |
text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
|
74 |
length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
else:
|
85 |
-
st.error("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import subprocess
|
2 |
import os
|
3 |
import glob
|
4 |
+
import streamlit as st
|
5 |
+
import streamlit.components.v1 as components
|
6 |
+
import base64
|
7 |
+
|
8 |
+
# Function to find the latest file with a given extension in a specified directory
|
9 |
+
def find_latest_file(base_path, extension):
|
10 |
+
list_of_files = glob.glob(f'{base_path}/*.{extension}')
|
11 |
+
if not list_of_files:
|
12 |
+
return None
|
13 |
+
latest_file = max(list_of_files, key=os.path.getctime)
|
14 |
+
return latest_file
|
15 |
|
16 |
+
# Function to run the generate.py script and return paths of generated HTML and NPY files
|
17 |
def generate_html(text_input, length):
|
18 |
command = [
|
19 |
"python", "generate.py",
|
|
|
24 |
]
|
25 |
try:
|
26 |
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
27 |
+
html_file = find_latest_file('output', 'html')
|
28 |
+
npy_file = find_latest_file('output', 'npy')
|
29 |
+
return html_file, npy_file
|
30 |
except subprocess.CalledProcessError as e:
|
31 |
+
st.error(f"Error: {e.stderr}")
|
32 |
+
return None, None
|
33 |
|
34 |
+
# Function to run render_final.py script with the generated NPY file
|
35 |
+
def run_render_final(npy_file_path):
|
36 |
+
command = ["python", "render_final.py", npy_file_path]
|
37 |
+
try:
|
38 |
+
gif_res = subprocess.run(command, check=True, text=True, capture_output=True)
|
39 |
+
gif_file_path = find_latest_file('output', 'gif')
|
40 |
+
return gif_file_path
|
41 |
+
except subprocess.CalledProcessError as e:
|
42 |
+
st.error(f"Error: {e.stderr}")
|
43 |
return None
|
44 |
+
|
45 |
+
# Function to convert GIF to base64
|
46 |
+
def gif_to_base64(gif_file_path):
|
47 |
+
with open(gif_file_path, "rb") as gif_file:
|
48 |
+
gif_bytes = gif_file.read()
|
49 |
+
base64_gif = base64.b64encode(gif_bytes).decode("utf-8")
|
50 |
+
return base64_gif
|
51 |
|
52 |
# Initialize session state
|
53 |
if 'text_input' not in st.session_state:
|
|
|
55 |
if 'length' not in st.session_state:
|
56 |
st.session_state.length = 156
|
57 |
|
58 |
+
# Handler to update session state and rerun the app
|
59 |
def select_prompt(prompt, prompt_length):
|
60 |
st.session_state.text_input = prompt
|
61 |
st.session_state.length = prompt_length
|
62 |
|
|
|
|
|
|
|
|
|
63 |
|
64 |
+
# App layout
|
65 |
+
components.html("<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>", height=100)
|
66 |
|
67 |
prompts = [
|
68 |
("A person walks forward then turns completely around and does a cartwheel", 196),
|
|
|
92 |
text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
|
93 |
length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
|
94 |
|
95 |
+
# Place the buttons side by side
|
96 |
+
button_col1, button_col2 = st.columns(2)
|
97 |
+
|
98 |
+
with button_col1:
|
99 |
+
if st.button("Generate HTML"):
|
100 |
+
if st.session_state.text_input and st.session_state.length:
|
101 |
+
html_file_path, npy_file_path = generate_html(st.session_state.text_input, st.session_state.length)
|
102 |
+
if html_file_path and npy_file_path:
|
103 |
+
st.session_state.html_file_path = html_file_path
|
104 |
+
st.session_state.npy_file_path = npy_file_path
|
105 |
+
|
106 |
+
# Display the HTML file content
|
107 |
+
with open(html_file_path, 'r') as file:
|
108 |
+
html_content = file.read()
|
109 |
+
st.session_state.html_content = html_content
|
110 |
+
else:
|
111 |
+
st.error("Error generating files. Please try again.")
|
112 |
+
|
113 |
+
with button_col2:
|
114 |
+
if st.button("Render Skeleton"):
|
115 |
+
if 'npy_file_path' in st.session_state and st.session_state.npy_file_path:
|
116 |
+
gif_file_path = run_render_final(st.session_state.npy_file_path)
|
117 |
+
if gif_file_path:
|
118 |
+
st.session_state.gif_file_path = gif_file_path
|
119 |
+
st.session_state.gif_base64 = gif_to_base64(gif_file_path)
|
120 |
else:
|
121 |
+
st.error("No npy file found. Please generate HTML first.")
|
122 |
+
|
123 |
+
|
124 |
+
# Display the results side by side using HTML components
|
125 |
+
if 'html_content' in st.session_state or 'gif_base64' in st.session_state:
|
126 |
+
html_content = st.session_state.html_content if 'html_content' in st.session_state else ""
|
127 |
+
gif_base64 = st.session_state.gif_base64 if 'gif_base64' in st.session_state else ""
|
128 |
+
|
129 |
+
disp_col1, disp_col2 = st.columns([1, 1])
|
130 |
+
|
131 |
+
with disp_col1:
|
132 |
+
components.html(html_content, height=800, scrolling=True)
|
133 |
+
|
134 |
+
with disp_col2:
|
135 |
+
if gif_base64:
|
136 |
+
gif_html = f'<img src="data:image/gif;base64,{gif_base64}" style="width:100%;">'
|
137 |
+
components.html(gif_html, height=800, scrolling=True)
|
generate.py
CHANGED
@@ -291,15 +291,22 @@ if __name__ == '__main__':
|
|
291 |
pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
|
292 |
num_joints = 22
|
293 |
|
294 |
-
|
|
|
|
|
|
|
|
|
295 |
|
296 |
-
|
|
|
297 |
|
298 |
-
|
|
|
|
|
|
|
|
|
299 |
print('File saved successfully')
|
300 |
|
301 |
-
std = np.load('./exit/t2m-std.npy')
|
302 |
-
mean = np.load('./exit/t2m-mean.npy')
|
303 |
file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
304 |
visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
|
305 |
|
|
|
291 |
pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
|
292 |
num_joints = 22
|
293 |
|
294 |
+
std = np.load('./exit/t2m-std.npy')
|
295 |
+
mean = np.load('./exit/t2m-mean.npy')
|
296 |
+
|
297 |
+
norm_pose = pred_pose[0].detach().cpu().numpy() * std + mean
|
298 |
+
norm_pose = torch.tensor(norm_pose)
|
299 |
|
300 |
+
trimmed_pose = norm_pose[:args.length, :].unsqueeze(0).float()
|
301 |
+
print(trimmed_pose.shape)
|
302 |
|
303 |
+
converted_pose = recover_from_ric(trimmed_pose[0].detach().cpu(), num_joints).unsqueeze(0).numpy()
|
304 |
+
print(converted_pose.shape)
|
305 |
+
|
306 |
+
filename = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
307 |
+
np.save('./output/'+filename+'.npy', converted_pose)
|
308 |
print('File saved successfully')
|
309 |
|
|
|
|
|
310 |
file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
|
311 |
visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
|
312 |
|
render_final.py
CHANGED
@@ -164,8 +164,7 @@ def render(motions, outdir='test_vis', device_id=0, name=None, pred=True):
|
|
164 |
gif_path = os.path.join(outdir, f'{name}.gif')
|
165 |
imageio.mimsave(gif_path, out, fps=20)
|
166 |
|
167 |
-
|
168 |
-
|
169 |
|
170 |
|
171 |
if __name__ == "__main__":
|
|
|
164 |
gif_path = os.path.join(outdir, f'{name}.gif')
|
165 |
imageio.mimsave(gif_path, out, fps=20)
|
166 |
|
167 |
+
|
|
|
168 |
|
169 |
|
170 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
@@ -1,37 +1,101 @@
|
|
1 |
-
|
2 |
-
|
3 |
beautifulsoup4==4.12.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
einops==0.8.0
|
5 |
fastjsonschema==2.20.0
|
|
|
|
|
6 |
freetype-py==2.4.0
|
7 |
fsspec==2024.6.1
|
8 |
ftfy==6.2.0
|
9 |
gdown==5.2.0
|
|
|
|
|
10 |
h5py==3.11.0
|
|
|
11 |
imageio==2.34.2
|
|
|
|
|
12 |
jsonschema==4.22.0
|
13 |
jsonschema-specifications==2023.12.1
|
14 |
jupyter_core==5.7.2
|
|
|
15 |
mapbox-earcut==1.0.1
|
|
|
|
|
|
|
|
|
|
|
16 |
nbformat==5.10.4
|
|
|
17 |
numpy==1.23.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
pkgutil_resolve_name==1.3.10
|
19 |
-
|
|
|
20 |
protobuf==3.20.3
|
|
|
|
|
21 |
pyasn1-modules==0.2.8
|
|
|
22 |
pyglet==2.0.15
|
|
|
|
|
23 |
pyrender==0.1.45
|
|
|
|
|
|
|
24 |
referencing==0.35.1
|
25 |
regex==2024.5.15
|
|
|
26 |
requests-oauthlib==1.3.0
|
|
|
27 |
rpds-py==0.18.1
|
28 |
scipy==1.10.1
|
29 |
shapely==2.0.4
|
|
|
|
|
|
|
30 |
soupsieve==2.5
|
|
|
|
|
|
|
|
|
|
|
31 |
torch==2.3.1
|
32 |
torchaudio==2.3.1
|
33 |
torchvision==0.18.1
|
|
|
34 |
tqdm==4.66.4
|
35 |
traitlets==5.14.3
|
36 |
trimesh==4.4.1
|
|
|
|
|
|
|
|
|
|
|
37 |
wcwidth==0.2.13
|
|
|
|
1 |
+
altair==5.3.0
|
2 |
+
attrs==23.2.0
|
3 |
beautifulsoup4==4.12.3
|
4 |
+
blinker==1.8.2
|
5 |
+
cachetools==5.3.3
|
6 |
+
certifi==2024.6.2
|
7 |
+
charset-normalizer==3.3.2
|
8 |
+
chumpy==0.70
|
9 |
+
click==8.1.7
|
10 |
+
contourpy==1.1.1
|
11 |
+
cycler==0.12.1
|
12 |
einops==0.8.0
|
13 |
fastjsonschema==2.20.0
|
14 |
+
filelock==3.15.4
|
15 |
+
fonttools==4.53.0
|
16 |
freetype-py==2.4.0
|
17 |
fsspec==2024.6.1
|
18 |
ftfy==6.2.0
|
19 |
gdown==5.2.0
|
20 |
+
gitdb==4.0.11
|
21 |
+
GitPython==3.1.43
|
22 |
h5py==3.11.0
|
23 |
+
idna==3.7
|
24 |
imageio==2.34.2
|
25 |
+
importlib_resources==6.4.0
|
26 |
+
Jinja2==3.1.4
|
27 |
jsonschema==4.22.0
|
28 |
jsonschema-specifications==2023.12.1
|
29 |
jupyter_core==5.7.2
|
30 |
+
kiwisolver==1.4.5
|
31 |
mapbox-earcut==1.0.1
|
32 |
+
markdown-it-py==3.0.0
|
33 |
+
MarkupSafe==2.1.5
|
34 |
+
matplotlib==3.7.5
|
35 |
+
mdurl==0.1.2
|
36 |
+
mpmath==1.3.0
|
37 |
nbformat==5.10.4
|
38 |
+
networkx==3.1
|
39 |
numpy==1.23.3
|
40 |
+
nvidia-cublas-cu12==12.1.3.1
|
41 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
42 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
43 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
44 |
+
nvidia-cudnn-cu12==8.9.2.26
|
45 |
+
nvidia-cufft-cu12==11.0.2.54
|
46 |
+
nvidia-curand-cu12==10.3.2.106
|
47 |
+
nvidia-cusolver-cu12==11.4.5.107
|
48 |
+
nvidia-cusparse-cu12==12.1.0.106
|
49 |
+
nvidia-nccl-cu12==2.20.5
|
50 |
+
nvidia-nvjitlink-cu12==12.5.40
|
51 |
+
nvidia-nvtx-cu12==12.1.105
|
52 |
+
oauthlib==3.2.2
|
53 |
+
packaging==24.1
|
54 |
+
pandas==2.0.3
|
55 |
+
pillow==10.3.0
|
56 |
pkgutil_resolve_name==1.3.10
|
57 |
+
platformdirs==4.2.2
|
58 |
+
plotly==5.22.0
|
59 |
protobuf==3.20.3
|
60 |
+
pyarrow==16.1.0
|
61 |
+
pyasn1==0.4.8
|
62 |
pyasn1-modules==0.2.8
|
63 |
+
pydeck==0.9.1
|
64 |
pyglet==2.0.15
|
65 |
+
Pygments==2.18.0
|
66 |
+
pyparsing==3.1.2
|
67 |
pyrender==0.1.45
|
68 |
+
PySocks==1.7.1
|
69 |
+
python-dateutil==2.9.0.post0
|
70 |
+
pytz==2024.1
|
71 |
referencing==0.35.1
|
72 |
regex==2024.5.15
|
73 |
+
requests==2.32.3
|
74 |
requests-oauthlib==1.3.0
|
75 |
+
rich==13.7.1
|
76 |
rpds-py==0.18.1
|
77 |
scipy==1.10.1
|
78 |
shapely==2.0.4
|
79 |
+
six==1.16.0
|
80 |
+
smmap==5.0.1
|
81 |
+
smplx==0.1.28
|
82 |
soupsieve==2.5
|
83 |
+
streamlit==1.36.0
|
84 |
+
sympy==1.12.1
|
85 |
+
tenacity==8.4.2
|
86 |
+
toml==0.10.2
|
87 |
+
toolz==0.12.1
|
88 |
torch==2.3.1
|
89 |
torchaudio==2.3.1
|
90 |
torchvision==0.18.1
|
91 |
+
tornado==6.4.1
|
92 |
tqdm==4.66.4
|
93 |
traitlets==5.14.3
|
94 |
trimesh==4.4.1
|
95 |
+
triton==2.3.1
|
96 |
+
typing_extensions==4.12.2
|
97 |
+
tzdata==2024.1
|
98 |
+
urllib3==2.2.2
|
99 |
+
watchdog==4.0.1
|
100 |
wcwidth==0.2.13
|
101 |
+
zipp==3.19.2
|