Pie31415 commited on
Commit
9887bdf
Β·
1 Parent(s): 57885e4

implemented video inference

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +72 -6
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Rome
3
- emoji: πŸ’©
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
  title: Rome
3
+ emoji: πŸ˜‚
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
app.py CHANGED
@@ -1,7 +1,16 @@
1
  import sys
2
  import torch
3
- import gradio as gr
4
  import pickle
 
 
 
 
 
 
 
 
 
 
5
 
6
  from easydict import EasyDict as edict
7
  from huggingface_hub import hf_hub_download
@@ -11,6 +20,7 @@ sys.path.append('./DECA')
11
 
12
  from rome.infer import Infer
13
  from rome.src.utils.processing import process_black_shape, tensor2image
 
14
 
15
  # loading models ---- create model repo
16
  default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
@@ -128,8 +138,64 @@ def image_inference(
128
  out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
129
  return res[..., ::-1]
130
 
131
- def video_inference():
132
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  with gr.Blocks() as demo:
135
  gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
@@ -151,8 +217,8 @@ with gr.Blocks() as demo:
151
  image_button = gr.Button("Predict")
152
  with gr.Tab("Video Inference"):
153
  with gr.Row():
154
- source_video = gr.Video(label="source video", )
155
- driver_image_for_vid = gr.Image(type="pil", label="driver image", show_label=True)
156
  video_output = gr.Image()
157
  video_button = gr.Button("Predict")
158
 
@@ -168,6 +234,6 @@ with gr.Blocks() as demo:
168
  )
169
 
170
  image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
171
- video_button.click(None, inputs=[source_video, driver_image_for_vid], outputs=video_output)
172
 
173
  demo.launch()
 
1
  import sys
2
  import torch
 
3
  import pickle
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ from collections import defaultdict
10
+ from glob import glob
11
+
12
+ from matplotlib import pyplot as plt
13
+ from matplotlib import animation
14
 
15
  from easydict import EasyDict as edict
16
  from huggingface_hub import hf_hub_download
 
20
 
21
  from rome.infer import Infer
22
  from rome.src.utils.processing import process_black_shape, tensor2image
23
+ from rome.src.utils.visuals import mask_errosion
24
 
25
  # loading models ---- create model repo
26
  default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
 
138
  out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
139
  return res[..., ::-1]
140
 
141
+ def extract_frames(driver_vid):
142
+ image_frames = []
143
+ vid = cv2.VideoCapture(driver_vid) # path to mp4
144
+
145
+ while True:
146
+ success, img = vid.read()
147
+
148
+ if not success: break
149
+
150
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
151
+ pil_img = Image.fromarray(img)
152
+ image_frames.append(pil_img)
153
+
154
+ return image_frames
155
+
156
+ def video_inference(source_img, driver_vid):
157
+ image_frames = extract_frames(driver_vid)
158
+
159
+ resulted_imgs = defaultdict(list)
160
+
161
+ video_folder = 'jenya_driver/'
162
+ image_frames = sorted(glob(f"{video_folder}/*", recursive=True), key=lambda x: int(x.split('/')[-1][:-4]))
163
+
164
+ mask_hard_threshold = 0.5
165
+ N = len(image_frames)//20
166
+ for i in range(0, N, 4):
167
+ new_out = infer.evaluate(source_img, Image.open(image_frames[i]),
168
+ source_information_for_reuse=out.get('source_information'))
169
+
170
+ mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
171
+ mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
172
+ render = new_out['pred_target_img'].cpu() * (mask_pred) + (1 - mask_pred)
173
+
174
+ normals = process_black_shape(((new_out['pred_target_normal'][0].cpu() + 1) / 2 * mask_pred + (1 - mask_pred) ) )
175
+ normals[normals==0.5]=1.
176
+
177
+ resulted_imgs['res_normal'].append(tensor2image(normals))
178
+ resulted_imgs['res_mesh_images'].append(tensor2image(new_out['pred_target_shape_img'][0]))
179
+ resulted_imgs['res_renders'].append(tensor2image(render[0]))
180
+
181
+ video = np.array(resulted_imgs['res_renders'])
182
+
183
+ fig = plt.figure()
184
+ im = plt.imshow(video[0,:,:,::-1])
185
+ plt.axis('off')
186
+ plt.close() # this is required to not display the generated image
187
+
188
+ def init():
189
+ im.set_data(video[0,:,:,::-1])
190
+
191
+ def animate(i):
192
+ im.set_data(video[i,:,:,::-1])
193
+ return im
194
+
195
+ anim = animation.FuncAnimation(fig, animate, init_func=init,
196
+ frames=video.shape[0], interval=30)
197
+
198
+ return anim
199
 
200
  with gr.Blocks() as demo:
201
  gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
 
217
  image_button = gr.Button("Predict")
218
  with gr.Tab("Video Inference"):
219
  with gr.Row():
220
+ source_img2 = gr.Image(type="pil", label="source image", show_label=True)
221
+ driver_vid = gr.Video(label="driver video")
222
  video_output = gr.Image()
223
  video_button = gr.Button("Predict")
224
 
 
234
  )
235
 
236
  image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
237
+ video_button.click(video_inference, inputs=[source_img2, driver_vid], outputs=video_output)
238
 
239
  demo.launch()
requirements.txt CHANGED
@@ -8,4 +8,5 @@ matplotlib
8
  pillow
9
  https://download.pytorch.org/whl/cu101/torch-1.6.0%2Bcu101-cp38-cp38-linux_x86_64.whl
10
  https://download.pytorch.org/whl/cu101/torchvision-0.7.0%2Bcu101-cp38-cp38-linux_x86_64.whl
11
- easydict
 
 
8
  pillow
9
  https://download.pytorch.org/whl/cu101/torch-1.6.0%2Bcu101-cp38-cp38-linux_x86_64.whl
10
  https://download.pytorch.org/whl/cu101/torchvision-0.7.0%2Bcu101-cp38-cp38-linux_x86_64.whl
11
+ easydict
12
+ opencv