zejunyang commited on
Commit
3a0bff5
·
1 Parent(s): 18f04c7
Files changed (3) hide show
  1. app.py +6 -6
  2. src/audio2vid.py +63 -64
  3. src/vid2vid.py +59 -62
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
 
3
- # from src.audio2vid import audio2video
4
- # from src.vid2vid import video2video
5
 
6
- from src.create_modules import Processer
7
 
8
  title = r"""
9
  <h1>AniPortrait</h1>
@@ -13,7 +13,7 @@ description = r"""
13
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
14
  """
15
 
16
- main_processer = Processer()
17
 
18
  with gr.Blocks() as demo:
19
 
@@ -77,13 +77,13 @@ with gr.Blocks() as demo:
77
  )
78
 
79
  a2v_botton.click(
80
- fn=main_processer.audio2video,
81
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
82
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
83
  outputs=[a2v_output_video, a2v_ref_img]
84
  )
85
  v2v_botton.click(
86
- fn=main_processer.video2video,
87
  inputs=[v2v_ref_img, v2v_source_video,
88
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
89
  outputs=[v2v_output_video, v2v_ref_img]
 
1
  import gradio as gr
2
 
3
+ from src.audio2vid import audio2video
4
+ from src.vid2vid import video2video
5
 
6
+ # from src.create_modules import Processer
7
 
8
  title = r"""
9
  <h1>AniPortrait</h1>
 
13
  <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
14
  """
15
 
16
+ # main_processer = Processer()
17
 
18
  with gr.Blocks() as demo:
19
 
 
77
  )
78
 
79
  a2v_botton.click(
80
+ fn=audio2video,
81
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
82
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
83
  outputs=[a2v_output_video, a2v_ref_img]
84
  )
85
  v2v_botton.click(
86
+ fn=video2video,
87
  inputs=[v2v_ref_img, v2v_source_video,
88
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
89
  outputs=[v2v_output_video, v2v_ref_img]
src/audio2vid.py CHANGED
@@ -9,27 +9,26 @@ import spaces
9
  from scipy.spatial.transform import Rotation as R
10
  from scipy.interpolate import interp1d
11
 
12
- # from diffusers import AutoencoderKL, DDIMScheduler
13
- # from einops import repeat
14
  from omegaconf import OmegaConf
15
  from PIL import Image
16
  from torchvision import transforms
17
- # from transformers import CLIPVisionModelWithProjection
18
 
19
 
20
- # from src.models.pose_guider import PoseGuider
21
- # from src.models.unet_2d_condition import UNet2DConditionModel
22
- # from src.models.unet_3d import UNet3DConditionModel
23
- # from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
  from src.utils.util import save_videos_grid
25
 
26
- # from src.audio_models.model import Audio2MeshModel
27
  from src.utils.audio_util import prepare_audio_feature
28
- # from src.utils.mp_utils import LMKExtractor
29
- # from src.utils.draw_util import FaceMeshVisualizer
30
  from src.utils.pose_util import project_points
31
  from src.utils.crop_face_single import crop_face
32
- from src.create_modules import lmk_extractor, vis, a2m_model, pipe
33
 
34
 
35
  def matrix_to_euler_and_translation(matrix):
@@ -51,7 +50,7 @@ def smooth_pose_seq(pose_seq, window_size=5):
51
  return smoothed_pose_seq
52
 
53
  def get_headpose_temp(input_video):
54
- # lmk_extractor = LMKExtractor()
55
  cap = cv2.VideoCapture(input_video)
56
 
57
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -100,70 +99,70 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
100
 
101
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
102
 
103
- # if config.weight_dtype == "fp16":
104
- # weight_dtype = torch.float16
105
- # else:
106
- # weight_dtype = torch.float32
107
 
108
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
109
- # # prepare model
110
- # a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
111
- # a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
112
- # a2m_model.cuda().eval()
113
 
114
- # vae = AutoencoderKL.from_pretrained(
115
- # config.pretrained_vae_path,
116
- # ).to("cuda", dtype=weight_dtype)
117
 
118
- # reference_unet = UNet2DConditionModel.from_pretrained(
119
- # config.pretrained_base_model_path,
120
- # subfolder="unet",
121
- # ).to(dtype=weight_dtype, device="cuda")
122
 
123
- # inference_config_path = config.inference_config
124
- # infer_config = OmegaConf.load(inference_config_path)
125
- # denoising_unet = UNet3DConditionModel.from_pretrained_2d(
126
- # config.pretrained_base_model_path,
127
- # config.motion_module_path,
128
- # subfolder="unet",
129
- # unet_additional_kwargs=infer_config.unet_additional_kwargs,
130
- # ).to(dtype=weight_dtype, device="cuda")
131
 
132
 
133
- # pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
134
 
135
- # image_enc = CLIPVisionModelWithProjection.from_pretrained(
136
- # config.image_encoder_path
137
- # ).to(dtype=weight_dtype, device="cuda")
138
 
139
- # sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
140
- # scheduler = DDIMScheduler(**sched_kwargs)
141
 
142
  generator = torch.manual_seed(seed)
143
 
144
  width, height = size, size
145
 
146
- # # load pretrained weights
147
- # denoising_unet.load_state_dict(
148
- # torch.load(config.denoising_unet_path, map_location="cpu"),
149
- # strict=False,
150
- # )
151
- # reference_unet.load_state_dict(
152
- # torch.load(config.reference_unet_path, map_location="cpu"),
153
- # )
154
- # pose_guider.load_state_dict(
155
- # torch.load(config.pose_guider_path, map_location="cpu"),
156
- # )
157
-
158
- # pipe = Pose2VideoPipeline(
159
- # vae=vae,
160
- # image_encoder=image_enc,
161
- # reference_unet=reference_unet,
162
- # denoising_unet=denoising_unet,
163
- # pose_guider=pose_guider,
164
- # scheduler=scheduler,
165
- # )
166
- # pipe = pipe.to("cuda", dtype=weight_dtype)
167
 
168
  date_str = datetime.now().strftime("%Y%m%d")
169
  time_str = datetime.now().strftime("%H%M")
@@ -172,8 +171,8 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
172
  save_dir = Path(f"output/{date_str}/{save_dir_name}")
173
  save_dir.mkdir(exist_ok=True, parents=True)
174
 
175
- # lmk_extractor = LMKExtractor()
176
- # vis = FaceMeshVisualizer(forehead_edge=False)
177
 
178
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
179
  ref_image_np = crop_face(ref_image_np, lmk_extractor)
 
9
  from scipy.spatial.transform import Rotation as R
10
  from scipy.interpolate import interp1d
11
 
12
+ from diffusers import AutoencoderKL, DDIMScheduler
13
+ from einops import repeat
14
  from omegaconf import OmegaConf
15
  from PIL import Image
16
  from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
 
19
 
20
+ from src.models.pose_guider import PoseGuider
21
+ from src.models.unet_2d_condition import UNet2DConditionModel
22
+ from src.models.unet_3d import UNet3DConditionModel
23
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
  from src.utils.util import save_videos_grid
25
 
26
+ from src.audio_models.model import Audio2MeshModel
27
  from src.utils.audio_util import prepare_audio_feature
28
+ from src.utils.mp_utils import LMKExtractor
29
+ from src.utils.draw_util import FaceMeshVisualizer
30
  from src.utils.pose_util import project_points
31
  from src.utils.crop_face_single import crop_face
 
32
 
33
 
34
  def matrix_to_euler_and_translation(matrix):
 
50
  return smoothed_pose_seq
51
 
52
  def get_headpose_temp(input_video):
53
+ lmk_extractor = LMKExtractor()
54
  cap = cv2.VideoCapture(input_video)
55
 
56
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
99
 
100
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
101
 
102
+ if config.weight_dtype == "fp16":
103
+ weight_dtype = torch.float16
104
+ else:
105
+ weight_dtype = torch.float32
106
 
107
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
108
+ # prepare model
109
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
110
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
111
+ a2m_model.cuda().eval()
112
 
113
+ vae = AutoencoderKL.from_pretrained(
114
+ config.pretrained_vae_path,
115
+ ).to("cuda", dtype=weight_dtype)
116
 
117
+ reference_unet = UNet2DConditionModel.from_pretrained(
118
+ config.pretrained_base_model_path,
119
+ subfolder="unet",
120
+ ).to(dtype=weight_dtype, device="cuda")
121
 
122
+ inference_config_path = config.inference_config
123
+ infer_config = OmegaConf.load(inference_config_path)
124
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
125
+ config.pretrained_base_model_path,
126
+ config.motion_module_path,
127
+ subfolder="unet",
128
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
129
+ ).to(dtype=weight_dtype, device="cuda")
130
 
131
 
132
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
133
 
134
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
135
+ config.image_encoder_path
136
+ ).to(dtype=weight_dtype, device="cuda")
137
 
138
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
139
+ scheduler = DDIMScheduler(**sched_kwargs)
140
 
141
  generator = torch.manual_seed(seed)
142
 
143
  width, height = size, size
144
 
145
+ # load pretrained weights
146
+ denoising_unet.load_state_dict(
147
+ torch.load(config.denoising_unet_path, map_location="cpu"),
148
+ strict=False,
149
+ )
150
+ reference_unet.load_state_dict(
151
+ torch.load(config.reference_unet_path, map_location="cpu"),
152
+ )
153
+ pose_guider.load_state_dict(
154
+ torch.load(config.pose_guider_path, map_location="cpu"),
155
+ )
156
+
157
+ pipe = Pose2VideoPipeline(
158
+ vae=vae,
159
+ image_encoder=image_enc,
160
+ reference_unet=reference_unet,
161
+ denoising_unet=denoising_unet,
162
+ pose_guider=pose_guider,
163
+ scheduler=scheduler,
164
+ )
165
+ pipe = pipe.to("cuda", dtype=weight_dtype)
166
 
167
  date_str = datetime.now().strftime("%Y%m%d")
168
  time_str = datetime.now().strftime("%H%M")
 
171
  save_dir = Path(f"output/{date_str}/{save_dir_name}")
172
  save_dir.mkdir(exist_ok=True, parents=True)
173
 
174
+ lmk_extractor = LMKExtractor()
175
+ vis = FaceMeshVisualizer(forehead_edge=False)
176
 
177
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
178
  ref_image_np = crop_face(ref_image_np, lmk_extractor)
src/vid2vid.py CHANGED
@@ -7,89 +7,88 @@ import numpy as np
7
  import cv2
8
  import torch
9
  import spaces
10
- # from diffusers import AutoencoderKL, DDIMScheduler
11
- # from einops import repeat
12
- # from omegaconf import OmegaConf
13
  from PIL import Image
14
  from torchvision import transforms
15
- # from transformers import CLIPVisionModelWithProjection
16
 
17
- # from src.models.pose_guider import PoseGuider
18
- # from src.models.unet_2d_condition import UNet2DConditionModel
19
- # from src.models.unet_3d import UNet3DConditionModel
20
- # from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
21
  from src.utils.util import get_fps, read_frames, save_videos_grid
22
 
23
- # from src.utils.mp_utils import LMKExtractor
24
- # from src.utils.draw_util import FaceMeshVisualizer
25
  from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
26
  from src.audio2vid import smooth_pose_seq
27
  from src.utils.crop_face_single import crop_face
28
- from src.create_modules import lmk_extractor, vis, pipe
29
 
30
  @spaces.GPU
31
  def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
32
  cfg = 3.5
33
 
34
- # config = OmegaConf.load('./configs/prompts/animation_facereenac.yaml')
35
 
36
- # if config.weight_dtype == "fp16":
37
- # weight_dtype = torch.float16
38
- # else:
39
- # weight_dtype = torch.float32
40
 
41
- # vae = AutoencoderKL.from_pretrained(
42
- # config.pretrained_vae_path,
43
- # ).to("cuda", dtype=weight_dtype)
44
 
45
- # reference_unet = UNet2DConditionModel.from_pretrained(
46
- # config.pretrained_base_model_path,
47
- # subfolder="unet",
48
- # ).to(dtype=weight_dtype, device="cuda")
49
 
50
- # inference_config_path = config.inference_config
51
- # infer_config = OmegaConf.load(inference_config_path)
52
- # denoising_unet = UNet3DConditionModel.from_pretrained_2d(
53
- # config.pretrained_base_model_path,
54
- # config.motion_module_path,
55
- # subfolder="unet",
56
- # unet_additional_kwargs=infer_config.unet_additional_kwargs,
57
- # ).to(dtype=weight_dtype, device="cuda")
58
 
59
- # pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
60
 
61
- # image_enc = CLIPVisionModelWithProjection.from_pretrained(
62
- # config.image_encoder_path
63
- # ).to(dtype=weight_dtype, device="cuda")
64
 
65
- # sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
66
- # scheduler = DDIMScheduler(**sched_kwargs)
67
 
68
  generator = torch.manual_seed(seed)
69
 
70
  width, height = size, size
71
 
72
- # # load pretrained weights
73
- # denoising_unet.load_state_dict(
74
- # torch.load(config.denoising_unet_path, map_location="cpu"),
75
- # strict=False,
76
- # )
77
- # reference_unet.load_state_dict(
78
- # torch.load(config.reference_unet_path, map_location="cpu"),
79
- # )
80
- # pose_guider.load_state_dict(
81
- # torch.load(config.pose_guider_path, map_location="cpu"),
82
- # )
83
-
84
- # pipe = Pose2VideoPipeline(
85
- # vae=vae,
86
- # image_encoder=image_enc,
87
- # reference_unet=reference_unet,
88
- # denoising_unet=denoising_unet,
89
- # pose_guider=pose_guider,
90
- # scheduler=scheduler,
91
- # )
92
- # pipe = pipe.to("cuda", dtype=weight_dtype)
93
 
94
  date_str = datetime.now().strftime("%Y%m%d")
95
  time_str = datetime.now().strftime("%H%M")
@@ -99,11 +98,9 @@ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
99
  save_dir.mkdir(exist_ok=True, parents=True)
100
 
101
 
102
- # lmk_extractor = LMKExtractor()
103
- # vis = FaceMeshVisualizer(forehead_edge=False)
104
 
105
-
106
-
107
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
108
  ref_image_np = crop_face(ref_image_np, lmk_extractor)
109
  if ref_image_np is None:
 
7
  import cv2
8
  import torch
9
  import spaces
10
+ from diffusers import AutoencoderKL, DDIMScheduler
11
+ from einops import repeat
12
+ from omegaconf import OmegaConf
13
  from PIL import Image
14
  from torchvision import transforms
15
+ from transformers import CLIPVisionModelWithProjection
16
 
17
+ from src.models.pose_guider import PoseGuider
18
+ from src.models.unet_2d_condition import UNet2DConditionModel
19
+ from src.models.unet_3d import UNet3DConditionModel
20
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
21
  from src.utils.util import get_fps, read_frames, save_videos_grid
22
 
23
+ from src.utils.mp_utils import LMKExtractor
24
+ from src.utils.draw_util import FaceMeshVisualizer
25
  from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
26
  from src.audio2vid import smooth_pose_seq
27
  from src.utils.crop_face_single import crop_face
 
28
 
29
  @spaces.GPU
30
  def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
31
  cfg = 3.5
32
 
33
+ config = OmegaConf.load('./configs/prompts/animation_facereenac.yaml')
34
 
35
+ if config.weight_dtype == "fp16":
36
+ weight_dtype = torch.float16
37
+ else:
38
+ weight_dtype = torch.float32
39
 
40
+ vae = AutoencoderKL.from_pretrained(
41
+ config.pretrained_vae_path,
42
+ ).to("cuda", dtype=weight_dtype)
43
 
44
+ reference_unet = UNet2DConditionModel.from_pretrained(
45
+ config.pretrained_base_model_path,
46
+ subfolder="unet",
47
+ ).to(dtype=weight_dtype, device="cuda")
48
 
49
+ inference_config_path = config.inference_config
50
+ infer_config = OmegaConf.load(inference_config_path)
51
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
52
+ config.pretrained_base_model_path,
53
+ config.motion_module_path,
54
+ subfolder="unet",
55
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
56
+ ).to(dtype=weight_dtype, device="cuda")
57
 
58
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
59
 
60
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
61
+ config.image_encoder_path
62
+ ).to(dtype=weight_dtype, device="cuda")
63
 
64
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
65
+ scheduler = DDIMScheduler(**sched_kwargs)
66
 
67
  generator = torch.manual_seed(seed)
68
 
69
  width, height = size, size
70
 
71
+ # load pretrained weights
72
+ denoising_unet.load_state_dict(
73
+ torch.load(config.denoising_unet_path, map_location="cpu"),
74
+ strict=False,
75
+ )
76
+ reference_unet.load_state_dict(
77
+ torch.load(config.reference_unet_path, map_location="cpu"),
78
+ )
79
+ pose_guider.load_state_dict(
80
+ torch.load(config.pose_guider_path, map_location="cpu"),
81
+ )
82
+
83
+ pipe = Pose2VideoPipeline(
84
+ vae=vae,
85
+ image_encoder=image_enc,
86
+ reference_unet=reference_unet,
87
+ denoising_unet=denoising_unet,
88
+ pose_guider=pose_guider,
89
+ scheduler=scheduler,
90
+ )
91
+ pipe = pipe.to("cuda", dtype=weight_dtype)
92
 
93
  date_str = datetime.now().strftime("%Y%m%d")
94
  time_str = datetime.now().strftime("%H%M")
 
98
  save_dir.mkdir(exist_ok=True, parents=True)
99
 
100
 
101
+ lmk_extractor = LMKExtractor()
102
+ vis = FaceMeshVisualizer(forehead_edge=False)
103
 
 
 
104
  ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
105
  ref_image_np = crop_face(ref_image_np, lmk_extractor)
106
  if ref_image_np is None: