faraday commited on
Commit
86716b3
·
1 Parent(s): 88590fc
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  import shutil
4
  import subprocess
@@ -42,12 +41,13 @@ DEFAULT_MODEL_ARGS = {
42
  #'audio_attention_weight': 3.0
43
  }
44
 
45
- @spaces.GPU(duration=600)
 
46
  def infer(reference_image, audio_path, kps_sequence_save_path,
47
  output_path,
48
  retarget_strategy,
49
  reference_attention_weight, audio_attention_weight):
50
- INFERENCE_ENGINE = InferenceEngine(DEFAULT_MODEL_ARGS)
51
  INFERENCE_ENGINE.infer(
52
  reference_image, audio_path, kps_sequence_save_path,
53
  output_path,
@@ -57,7 +57,6 @@ def infer(reference_image, audio_path, kps_sequence_save_path,
57
  return output_path, kps_sequence_save_path
58
 
59
  # Function to run V-Express demo
60
- @spaces.GPU(duration=600)
61
  def run_demo(
62
  reference_image, audio, video,
63
  kps_path, output_path, retarget_strategy,
 
 
1
  import gradio as gr
2
  import shutil
3
  import subprocess
 
41
  #'audio_attention_weight': 3.0
42
  }
43
 
44
+ INFERENCE_ENGINE = InferenceEngine(DEFAULT_MODEL_ARGS)
45
+
46
  def infer(reference_image, audio_path, kps_sequence_save_path,
47
  output_path,
48
  retarget_strategy,
49
  reference_attention_weight, audio_attention_weight):
50
+ global INFERENCE_ENGINE
51
  INFERENCE_ENGINE.infer(
52
  reference_image, audio_path, kps_sequence_save_path,
53
  output_path,
 
57
  return output_path, kps_sequence_save_path
58
 
59
  # Function to run V-Express demo
 
60
  def run_demo(
61
  reference_image, audio, video,
62
  kps_path, output_path, retarget_strategy,
inference.py CHANGED
@@ -1,6 +1,3 @@
1
- import spaces
2
- import argparse
3
-
4
  import os
5
  import cv2
6
  import numpy as np
@@ -20,14 +17,14 @@ from pipelines import VExpressPipeline
20
  from pipelines.utils import draw_kps_image, save_video
21
  from pipelines.utils import retarget_kps
22
 
23
- @spaces.GPU
24
  def load_reference_net(unet_config_path, reference_net_path, dtype, device):
25
  reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
26
  reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
27
  print(f'Loaded weights of Reference Net from {reference_net_path}.')
28
  return reference_net
29
 
30
- @spaces.GPU
31
  def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
32
  inference_config_path = './inference_v2.yaml'
33
  inference_config = OmegaConf.load(inference_config_path)
@@ -43,14 +40,14 @@ def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_pat
43
 
44
  return denoising_unet
45
 
46
- @spaces.GPU
47
  def load_v_kps_guider(v_kps_guider_path, dtype, device):
48
  v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
49
  v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
50
  print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
51
  return v_kps_guider
52
 
53
- @spaces.GPU
54
  def load_audio_projection(
55
  audio_projection_path,
56
  dtype,
@@ -76,7 +73,7 @@ def load_audio_projection(
76
  print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
77
  return audio_projection
78
 
79
- @spaces.GPU
80
  def get_scheduler():
81
  inference_config_path = './inference_v2.yaml'
82
  inference_config = OmegaConf.load(inference_config_path)
@@ -86,7 +83,7 @@ def get_scheduler():
86
 
87
  class InferenceEngine(object):
88
 
89
- @spaces.GPU
90
  def __init__(self, args):
91
  self.init_params(args)
92
  self.load_models()
@@ -94,7 +91,7 @@ class InferenceEngine(object):
94
  self.set_vexpress_pipeline()
95
  self.set_face_analysis_app()
96
 
97
- @spaces.GPU
98
  def init_params(self, args):
99
  for key, value in args.items():
100
  setattr(self, key, value)
@@ -103,7 +100,7 @@ class InferenceEngine(object):
103
  print("Image height: ", self.image_height)
104
 
105
 
106
- @spaces.GPU
107
  def load_models(self):
108
  self.device = torch.device(f'cuda:{self.gpu_id}')
109
  self.dtype = torch.float16 if self.dtype == 'fp16' else torch.float32
@@ -134,11 +131,11 @@ class InferenceEngine(object):
134
  else:
135
  raise ValueError("xformers is not available. Make sure it is installed correctly")
136
 
137
- @spaces.GPU
138
  def set_generator(self):
139
  self.generator = torch.manual_seed(self.seed)
140
 
141
- @spaces.GPU
142
  def set_vexpress_pipeline(self):
143
  print("VAE exists (2): ", self.vae)
144
  self.pipeline = VExpressPipeline(
@@ -152,7 +149,7 @@ class InferenceEngine(object):
152
  scheduler=self.scheduler,
153
  ).to(dtype=self.dtype, device=self.device)
154
 
155
- @spaces.GPU
156
  def set_face_analysis_app(self):
157
  self.app = FaceAnalysis(
158
  providers=['CUDAExecutionProvider'],
@@ -161,7 +158,7 @@ class InferenceEngine(object):
161
  )
162
  self.app.prepare(ctx_id=0, det_size=(self.image_height, self.image_width))
163
 
164
- @spaces.GPU
165
  def get_reference_image_for_kps(self, reference_image_path):
166
  reference_image = Image.open(reference_image_path).convert('RGB')
167
  print("Image width ???", self.image_width)
@@ -172,7 +169,7 @@ class InferenceEngine(object):
172
  reference_kps = self.app.get(reference_image_for_kps)[0].kps[:3]
173
  return reference_image, reference_image_for_kps, reference_kps
174
 
175
- @spaces.GPU
176
  def get_waveform_video_length(self, audio_path):
177
  _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec')
178
  audio_sampling_rate = meta_info['audio_fps']
@@ -190,7 +187,7 @@ class InferenceEngine(object):
190
  print(f'The corresponding video length is {video_length}.')
191
  return audio_waveform, video_length
192
 
193
- @spaces.GPU
194
  def get_kps_sequence(self, kps_path, reference_kps, video_length, retarget_strategy):
195
  if kps_path != "":
196
  assert os.path.exists(kps_path), f'{kps_path} does not exist'
@@ -213,7 +210,7 @@ class InferenceEngine(object):
213
 
214
  return kps_sequence
215
 
216
- @spaces.GPU
217
  def get_kps_images(self, kps_sequence, reference_image_for_kps, video_length):
218
  kps_images = []
219
  for i in range(video_length):
@@ -222,7 +219,6 @@ class InferenceEngine(object):
222
  kps_images.append(Image.fromarray(kps_image))
223
  return kps_images
224
 
225
- @spaces.GPU(duration=600)
226
  def get_video_latents(self, reference_image, kps_images, audio_waveform, video_length, reference_attention_weight, audio_attention_weight):
227
  vae_scale_factor = 8
228
  latent_height = self.image_height // vae_scale_factor
@@ -252,19 +248,18 @@ class InferenceEngine(object):
252
 
253
  return video_latents
254
 
255
- @spaces.GPU
256
  def get_video_tensor(self, video_latents):
257
  video_tensor = self.pipeline.decode_latents(video_latents)
258
  if isinstance(video_tensor, np.ndarray):
259
  video_tensor = torch.from_numpy(video_tensor)
260
  return video_tensor
261
 
262
- @spaces.GPU
263
  def save_video_tensor(self, video_tensor, audio_path, output_path):
264
  save_video(video_tensor, audio_path, output_path, self.fps)
265
  print(f'The generated video has been saved at {output_path}.')
266
 
267
- @spaces.GPU(duration=600)
268
  def infer(
269
  self,
270
  reference_image_path, audio_path, kps_path,
 
 
 
 
1
  import os
2
  import cv2
3
  import numpy as np
 
17
  from pipelines.utils import draw_kps_image, save_video
18
  from pipelines.utils import retarget_kps
19
 
20
+
21
  def load_reference_net(unet_config_path, reference_net_path, dtype, device):
22
  reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
23
  reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
24
  print(f'Loaded weights of Reference Net from {reference_net_path}.')
25
  return reference_net
26
 
27
+
28
  def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
29
  inference_config_path = './inference_v2.yaml'
30
  inference_config = OmegaConf.load(inference_config_path)
 
40
 
41
  return denoising_unet
42
 
43
+
44
  def load_v_kps_guider(v_kps_guider_path, dtype, device):
45
  v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
46
  v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
47
  print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
48
  return v_kps_guider
49
 
50
+
51
  def load_audio_projection(
52
  audio_projection_path,
53
  dtype,
 
73
  print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
74
  return audio_projection
75
 
76
+
77
  def get_scheduler():
78
  inference_config_path = './inference_v2.yaml'
79
  inference_config = OmegaConf.load(inference_config_path)
 
83
 
84
  class InferenceEngine(object):
85
 
86
+
87
  def __init__(self, args):
88
  self.init_params(args)
89
  self.load_models()
 
91
  self.set_vexpress_pipeline()
92
  self.set_face_analysis_app()
93
 
94
+
95
  def init_params(self, args):
96
  for key, value in args.items():
97
  setattr(self, key, value)
 
100
  print("Image height: ", self.image_height)
101
 
102
 
103
+
104
  def load_models(self):
105
  self.device = torch.device(f'cuda:{self.gpu_id}')
106
  self.dtype = torch.float16 if self.dtype == 'fp16' else torch.float32
 
131
  else:
132
  raise ValueError("xformers is not available. Make sure it is installed correctly")
133
 
134
+
135
  def set_generator(self):
136
  self.generator = torch.manual_seed(self.seed)
137
 
138
+
139
  def set_vexpress_pipeline(self):
140
  print("VAE exists (2): ", self.vae)
141
  self.pipeline = VExpressPipeline(
 
149
  scheduler=self.scheduler,
150
  ).to(dtype=self.dtype, device=self.device)
151
 
152
+
153
  def set_face_analysis_app(self):
154
  self.app = FaceAnalysis(
155
  providers=['CUDAExecutionProvider'],
 
158
  )
159
  self.app.prepare(ctx_id=0, det_size=(self.image_height, self.image_width))
160
 
161
+
162
  def get_reference_image_for_kps(self, reference_image_path):
163
  reference_image = Image.open(reference_image_path).convert('RGB')
164
  print("Image width ???", self.image_width)
 
169
  reference_kps = self.app.get(reference_image_for_kps)[0].kps[:3]
170
  return reference_image, reference_image_for_kps, reference_kps
171
 
172
+
173
  def get_waveform_video_length(self, audio_path):
174
  _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec')
175
  audio_sampling_rate = meta_info['audio_fps']
 
187
  print(f'The corresponding video length is {video_length}.')
188
  return audio_waveform, video_length
189
 
190
+
191
  def get_kps_sequence(self, kps_path, reference_kps, video_length, retarget_strategy):
192
  if kps_path != "":
193
  assert os.path.exists(kps_path), f'{kps_path} does not exist'
 
210
 
211
  return kps_sequence
212
 
213
+
214
  def get_kps_images(self, kps_sequence, reference_image_for_kps, video_length):
215
  kps_images = []
216
  for i in range(video_length):
 
219
  kps_images.append(Image.fromarray(kps_image))
220
  return kps_images
221
 
 
222
  def get_video_latents(self, reference_image, kps_images, audio_waveform, video_length, reference_attention_weight, audio_attention_weight):
223
  vae_scale_factor = 8
224
  latent_height = self.image_height // vae_scale_factor
 
248
 
249
  return video_latents
250
 
251
+
252
  def get_video_tensor(self, video_latents):
253
  video_tensor = self.pipeline.decode_latents(video_latents)
254
  if isinstance(video_tensor, np.ndarray):
255
  video_tensor = torch.from_numpy(video_tensor)
256
  return video_tensor
257
 
258
+
259
  def save_video_tensor(self, video_tensor, audio_path, output_path):
260
  save_video(video_tensor, audio_path, output_path, self.fps)
261
  print(f'The generated video has been saved at {output_path}.')
262
 
 
263
  def infer(
264
  self,
265
  reference_image_path, audio_path, kps_path,
requirements.txt CHANGED
@@ -15,4 +15,3 @@ tqdm==4.66.1
15
  xformers==0.0.20
16
  accelerate==0.19.0
17
  gitpython==3.1.31
18
- spaces==0.28.3
 
15
  xformers==0.0.20
16
  accelerate==0.19.0
17
  gitpython==3.1.31
 
scripts/extract_kps_sequence_and_audio.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import argparse
3
 
4
  import os
@@ -7,7 +6,6 @@ import torch
7
  from insightface.app import FaceAnalysis
8
  from imageio_ffmpeg import get_ffmpeg_exe
9
 
10
- @spaces.GPU
11
  def main(args):
12
  app = FaceAnalysis(
13
  providers=['CUDAExecutionProvider'],
 
 
1
  import argparse
2
 
3
  import os
 
6
  from insightface.app import FaceAnalysis
7
  from imageio_ffmpeg import get_ffmpeg_exe
8
 
 
9
  def main(args):
10
  app = FaceAnalysis(
11
  providers=['CUDAExecutionProvider'],