samadi10 commited on
Commit
483442c
1 Parent(s): d63dc03

feature: added rendering

Browse files
Files changed (4) hide show
  1. app.py +78 -26
  2. generate.py +12 -5
  3. render_final.py +1 -2
  4. requirements.txt +67 -3
app.py CHANGED
@@ -1,10 +1,19 @@
1
- import streamlit as st
2
- import streamlit.components.v1 as components
3
  import subprocess
4
  import os
5
  import glob
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # Function to run the model and generate HTML
8
  def generate_html(text_input, length):
9
  command = [
10
  "python", "generate.py",
@@ -15,18 +24,30 @@ def generate_html(text_input, length):
15
  ]
16
  try:
17
  result = subprocess.run(command, check=True, text=True, capture_output=True)
18
- return find_latest_html_file('output')
 
 
19
  except subprocess.CalledProcessError as e:
20
- print("Error:", e.stderr)
21
- return None
22
 
23
- # Function to find the latest HTML file in the specified directory
24
- def find_latest_html_file(base_path):
25
- list_of_files = glob.glob(f'{base_path}/*.html')
26
- if not list_of_files:
 
 
 
 
 
27
  return None
28
- latest_file = max(list_of_files, key=os.path.getctime)
29
- return latest_file
 
 
 
 
 
30
 
31
  # Initialize session state
32
  if 'text_input' not in st.session_state:
@@ -34,16 +55,14 @@ if 'text_input' not in st.session_state:
34
  if 'length' not in st.session_state:
35
  st.session_state.length = 156
36
 
37
- # handler to update session state and rerun the app
38
  def select_prompt(prompt, prompt_length):
39
  st.session_state.text_input = prompt
40
  st.session_state.length = prompt_length
41
 
42
- # app layout
43
- components.html("""
44
- <h1 style='text-align: center; color: white;'>MMM Model Demo</h1>
45
- """, height=100)
46
 
 
 
47
 
48
  prompts = [
49
  ("A person walks forward then turns completely around and does a cartwheel", 196),
@@ -73,13 +92,46 @@ with input_placeholder.container():
73
  text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
74
  length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
75
 
76
- # Button trigger to generate HTML
77
- if st.button("Generate HTML"):
78
- if st.session_state.text_input and st.session_state.length:
79
- html_file_path = generate_html(st.session_state.text_input, st.session_state.length)
80
- if html_file_path and os.path.exists(html_file_path):
81
- with open(html_file_path, 'r') as file:
82
- html_content = file.read()
83
- components.html(html_content, height=800, width=800)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  else:
85
- st.error("Error generating HTML file. Please try again.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import subprocess
2
  import os
3
  import glob
4
+ import streamlit as st
5
+ import streamlit.components.v1 as components
6
+ import base64
7
+
8
+ # Function to find the latest file with a given extension in a specified directory
9
+ def find_latest_file(base_path, extension):
10
+ list_of_files = glob.glob(f'{base_path}/*.{extension}')
11
+ if not list_of_files:
12
+ return None
13
+ latest_file = max(list_of_files, key=os.path.getctime)
14
+ return latest_file
15
 
16
+ # Function to run the generate.py script and return paths of generated HTML and NPY files
17
  def generate_html(text_input, length):
18
  command = [
19
  "python", "generate.py",
 
24
  ]
25
  try:
26
  result = subprocess.run(command, check=True, text=True, capture_output=True)
27
+ html_file = find_latest_file('output', 'html')
28
+ npy_file = find_latest_file('output', 'npy')
29
+ return html_file, npy_file
30
  except subprocess.CalledProcessError as e:
31
+ st.error(f"Error: {e.stderr}")
32
+ return None, None
33
 
34
+ # Function to run render_final.py script with the generated NPY file
35
+ def run_render_final(npy_file_path):
36
+ command = ["python", "render_final.py", npy_file_path]
37
+ try:
38
+ gif_res = subprocess.run(command, check=True, text=True, capture_output=True)
39
+ gif_file_path = find_latest_file('output', 'gif')
40
+ return gif_file_path
41
+ except subprocess.CalledProcessError as e:
42
+ st.error(f"Error: {e.stderr}")
43
  return None
44
+
45
+ # Function to convert GIF to base64
46
+ def gif_to_base64(gif_file_path):
47
+ with open(gif_file_path, "rb") as gif_file:
48
+ gif_bytes = gif_file.read()
49
+ base64_gif = base64.b64encode(gif_bytes).decode("utf-8")
50
+ return base64_gif
51
 
52
  # Initialize session state
53
  if 'text_input' not in st.session_state:
 
55
  if 'length' not in st.session_state:
56
  st.session_state.length = 156
57
 
58
+ # Handler to update session state and rerun the app
59
  def select_prompt(prompt, prompt_length):
60
  st.session_state.text_input = prompt
61
  st.session_state.length = prompt_length
62
 
 
 
 
 
63
 
64
+ # App layout
65
+ components.html("<h1 style='text-align: center; color: white;'>MMM Model Demo</h1>", height=100)
66
 
67
  prompts = [
68
  ("A person walks forward then turns completely around and does a cartwheel", 196),
 
92
  text_input = st.text_area("Enter text here:", value=st.session_state.text_input, key="text_input", height=300)
93
  length = st.number_input("Length of the generated text:", value=st.session_state.length, key="length")
94
 
95
+ # Place the buttons side by side
96
+ button_col1, button_col2 = st.columns(2)
97
+
98
+ with button_col1:
99
+ if st.button("Generate HTML"):
100
+ if st.session_state.text_input and st.session_state.length:
101
+ html_file_path, npy_file_path = generate_html(st.session_state.text_input, st.session_state.length)
102
+ if html_file_path and npy_file_path:
103
+ st.session_state.html_file_path = html_file_path
104
+ st.session_state.npy_file_path = npy_file_path
105
+
106
+ # Display the HTML file content
107
+ with open(html_file_path, 'r') as file:
108
+ html_content = file.read()
109
+ st.session_state.html_content = html_content
110
+ else:
111
+ st.error("Error generating files. Please try again.")
112
+
113
+ with button_col2:
114
+ if st.button("Render Skeleton"):
115
+ if 'npy_file_path' in st.session_state and st.session_state.npy_file_path:
116
+ gif_file_path = run_render_final(st.session_state.npy_file_path)
117
+ if gif_file_path:
118
+ st.session_state.gif_file_path = gif_file_path
119
+ st.session_state.gif_base64 = gif_to_base64(gif_file_path)
120
  else:
121
+ st.error("No npy file found. Please generate HTML first.")
122
+
123
+
124
+ # Display the results side by side using HTML components
125
+ if 'html_content' in st.session_state or 'gif_base64' in st.session_state:
126
+ html_content = st.session_state.html_content if 'html_content' in st.session_state else ""
127
+ gif_base64 = st.session_state.gif_base64 if 'gif_base64' in st.session_state else ""
128
+
129
+ disp_col1, disp_col2 = st.columns([1, 1])
130
+
131
+ with disp_col1:
132
+ components.html(html_content, height=800, scrolling=True)
133
+
134
+ with disp_col2:
135
+ if gif_base64:
136
+ gif_html = f'<img src="data:image/gif;base64,{gif_base64}" style="width:100%;">'
137
+ components.html(gif_html, height=800, scrolling=True)
generate.py CHANGED
@@ -291,15 +291,22 @@ if __name__ == '__main__':
291
  pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
292
  num_joints = 22
293
 
294
- pred_pose = pred_pose[:args.length, :].detach().cpu()
 
 
 
 
295
 
296
- converted_pose = recover_from_ric(pred_pose[0].detach().cpu(), num_joints).unsqueeze(0).numpy()
 
297
 
298
- np.save('./output/mmm-pred.npy', converted_pose)
 
 
 
 
299
  print('File saved successfully')
300
 
301
- std = np.load('./exit/t2m-std.npy')
302
- mean = np.load('./exit/t2m-mean.npy')
303
  file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
304
  visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
305
 
 
291
  pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
292
  num_joints = 22
293
 
294
+ std = np.load('./exit/t2m-std.npy')
295
+ mean = np.load('./exit/t2m-mean.npy')
296
+
297
+ norm_pose = pred_pose[0].detach().cpu().numpy() * std + mean
298
+ norm_pose = torch.tensor(norm_pose)
299
 
300
+ trimmed_pose = norm_pose[:args.length, :].unsqueeze(0).float()
301
+ print(trimmed_pose.shape)
302
 
303
+ converted_pose = recover_from_ric(trimmed_pose[0].detach().cpu(), num_joints).unsqueeze(0).numpy()
304
+ print(converted_pose.shape)
305
+
306
+ filename = '_'.join(args.text.split(' '))+'_'+str(args.length)
307
+ np.save('./output/'+filename+'.npy', converted_pose)
308
  print('File saved successfully')
309
 
 
 
310
  file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
311
  visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
312
 
render_final.py CHANGED
@@ -164,8 +164,7 @@ def render(motions, outdir='test_vis', device_id=0, name=None, pred=True):
164
  gif_path = os.path.join(outdir, f'{name}.gif')
165
  imageio.mimsave(gif_path, out, fps=20)
166
 
167
-
168
-
169
 
170
 
171
  if __name__ == "__main__":
 
164
  gif_path = os.path.join(outdir, f'{name}.gif')
165
  imageio.mimsave(gif_path, out, fps=20)
166
 
167
+
 
168
 
169
 
170
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,37 +1,101 @@
1
- git+https://github.com/openai/CLIP.git
2
- git+https://github.com/mmatl/pyopengl.git
3
  beautifulsoup4==4.12.3
 
 
 
 
 
 
 
 
4
  einops==0.8.0
5
  fastjsonschema==2.20.0
 
 
6
  freetype-py==2.4.0
7
  fsspec==2024.6.1
8
  ftfy==6.2.0
9
  gdown==5.2.0
 
 
10
  h5py==3.11.0
 
11
  imageio==2.34.2
 
 
12
  jsonschema==4.22.0
13
  jsonschema-specifications==2023.12.1
14
  jupyter_core==5.7.2
 
15
  mapbox-earcut==1.0.1
 
 
 
 
 
16
  nbformat==5.10.4
 
17
  numpy==1.23.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  pkgutil_resolve_name==1.3.10
19
- plotly
 
20
  protobuf==3.20.3
 
 
21
  pyasn1-modules==0.2.8
 
22
  pyglet==2.0.15
 
 
23
  pyrender==0.1.45
 
 
 
24
  referencing==0.35.1
25
  regex==2024.5.15
 
26
  requests-oauthlib==1.3.0
 
27
  rpds-py==0.18.1
28
  scipy==1.10.1
29
  shapely==2.0.4
 
 
 
30
  soupsieve==2.5
 
 
 
 
 
31
  torch==2.3.1
32
  torchaudio==2.3.1
33
  torchvision==0.18.1
 
34
  tqdm==4.66.4
35
  traitlets==5.14.3
36
  trimesh==4.4.1
 
 
 
 
 
37
  wcwidth==0.2.13
 
 
1
+ altair==5.3.0
2
+ attrs==23.2.0
3
  beautifulsoup4==4.12.3
4
+ blinker==1.8.2
5
+ cachetools==5.3.3
6
+ certifi==2024.6.2
7
+ charset-normalizer==3.3.2
8
+ chumpy==0.70
9
+ click==8.1.7
10
+ contourpy==1.1.1
11
+ cycler==0.12.1
12
  einops==0.8.0
13
  fastjsonschema==2.20.0
14
+ filelock==3.15.4
15
+ fonttools==4.53.0
16
  freetype-py==2.4.0
17
  fsspec==2024.6.1
18
  ftfy==6.2.0
19
  gdown==5.2.0
20
+ gitdb==4.0.11
21
+ GitPython==3.1.43
22
  h5py==3.11.0
23
+ idna==3.7
24
  imageio==2.34.2
25
+ importlib_resources==6.4.0
26
+ Jinja2==3.1.4
27
  jsonschema==4.22.0
28
  jsonschema-specifications==2023.12.1
29
  jupyter_core==5.7.2
30
+ kiwisolver==1.4.5
31
  mapbox-earcut==1.0.1
32
+ markdown-it-py==3.0.0
33
+ MarkupSafe==2.1.5
34
+ matplotlib==3.7.5
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
  nbformat==5.10.4
38
+ networkx==3.1
39
  numpy==1.23.3
40
+ nvidia-cublas-cu12==12.1.3.1
41
+ nvidia-cuda-cupti-cu12==12.1.105
42
+ nvidia-cuda-nvrtc-cu12==12.1.105
43
+ nvidia-cuda-runtime-cu12==12.1.105
44
+ nvidia-cudnn-cu12==8.9.2.26
45
+ nvidia-cufft-cu12==11.0.2.54
46
+ nvidia-curand-cu12==10.3.2.106
47
+ nvidia-cusolver-cu12==11.4.5.107
48
+ nvidia-cusparse-cu12==12.1.0.106
49
+ nvidia-nccl-cu12==2.20.5
50
+ nvidia-nvjitlink-cu12==12.5.40
51
+ nvidia-nvtx-cu12==12.1.105
52
+ oauthlib==3.2.2
53
+ packaging==24.1
54
+ pandas==2.0.3
55
+ pillow==10.3.0
56
  pkgutil_resolve_name==1.3.10
57
+ platformdirs==4.2.2
58
+ plotly==5.22.0
59
  protobuf==3.20.3
60
+ pyarrow==16.1.0
61
+ pyasn1==0.4.8
62
  pyasn1-modules==0.2.8
63
+ pydeck==0.9.1
64
  pyglet==2.0.15
65
+ Pygments==2.18.0
66
+ pyparsing==3.1.2
67
  pyrender==0.1.45
68
+ PySocks==1.7.1
69
+ python-dateutil==2.9.0.post0
70
+ pytz==2024.1
71
  referencing==0.35.1
72
  regex==2024.5.15
73
+ requests==2.32.3
74
  requests-oauthlib==1.3.0
75
+ rich==13.7.1
76
  rpds-py==0.18.1
77
  scipy==1.10.1
78
  shapely==2.0.4
79
+ six==1.16.0
80
+ smmap==5.0.1
81
+ smplx==0.1.28
82
  soupsieve==2.5
83
+ streamlit==1.36.0
84
+ sympy==1.12.1
85
+ tenacity==8.4.2
86
+ toml==0.10.2
87
+ toolz==0.12.1
88
  torch==2.3.1
89
  torchaudio==2.3.1
90
  torchvision==0.18.1
91
+ tornado==6.4.1
92
  tqdm==4.66.4
93
  traitlets==5.14.3
94
  trimesh==4.4.1
95
+ triton==2.3.1
96
+ typing_extensions==4.12.2
97
+ tzdata==2024.1
98
+ urllib3==2.2.2
99
+ watchdog==4.0.1
100
  wcwidth==0.2.13
101
+ zipp==3.19.2