cocktailpeanut committed
Commit 7f1584d
1 Parent(s): fa7d98a
Files changed (2)
  1. app.py +15 -8
  2. requirements.txt +7 -6
app.py CHANGED
@@ -31,6 +31,13 @@ from src.utils.crop_face_single import crop_face
 from src.audio2vid import get_headpose_temp, smooth_pose_seq
 from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
 
+if torch.backends.mps.is_available():
+    device = "mps"
+elif torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
 
 config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
 if config.weight_dtype == "fp16":
@@ -42,16 +49,16 @@ audio_infer_config = OmegaConf.load(config.audio_inference_config)
 # prepare model
 a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
 a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
-a2m_model.cuda().eval()
+a2m_model.to(device).eval()
 
 vae = AutoencoderKL.from_pretrained(
     config.pretrained_vae_path,
-).to("cuda", dtype=weight_dtype)
+).to(device, dtype=weight_dtype)
 
 reference_unet = UNet2DConditionModel.from_pretrained(
     config.pretrained_base_model_path,
     subfolder="unet",
-).to(dtype=weight_dtype, device="cuda")
+).to(dtype=weight_dtype, device=device)
 
 inference_config_path = config.inference_config
 infer_config = OmegaConf.load(inference_config_path)
@@ -60,13 +67,13 @@ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
     config.motion_module_path,
     subfolder="unet",
     unet_additional_kwargs=infer_config.unet_additional_kwargs,
-).to(dtype=weight_dtype, device="cuda")
+).to(dtype=weight_dtype, device=device)
 
-pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
 image_enc = CLIPVisionModelWithProjection.from_pretrained(
     config.image_encoder_path
-).to(dtype=weight_dtype, device="cuda")
+).to(dtype=weight_dtype, device=device)
 
 sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
 scheduler = DDIMScheduler(**sched_kwargs)
@@ -91,7 +98,7 @@ pipe = Pose2VideoPipeline(
     pose_guider=pose_guider,
     scheduler=scheduler,
 )
-pipe = pipe.to("cuda", dtype=weight_dtype)
+pipe = pipe.to(device, dtype=weight_dtype)
 
 lmk_extractor = LMKExtractor()
 vis = FaceMeshVisualizer()
@@ -130,7 +137,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
 
     sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
-    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
+    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().to(device)
     sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
 
     # inference
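Taken together, these hunks swap every hardcoded .cuda() / device="cuda" placement for a single device string picked once at startup, which lets the Space run on Apple-silicon (MPS) and CPU hosts as well as CUDA GPUs. Below is a minimal, self-contained sketch of the same pattern, runnable outside this repo. The fp32-on-CPU fallback is an assumption added for illustration; the commit itself keeps whatever weight_dtype the YAML config specifies.

# Minimal sketch of the device-selection pattern introduced above.
import torch

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Assumption for illustration only: fall back to fp32 on CPU, where fp16
# inference is typically unsupported or very slow.
weight_dtype = torch.float16 if device != "cpu" else torch.float32

model = torch.nn.Linear(4, 4)  # stand-in for the real AniPortrait modules
model = model.to(device=device, dtype=weight_dtype).eval()

x = torch.randn(1, 4, device=device, dtype=weight_dtype)
with torch.no_grad():
    print(model(x).dtype, "on", device)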
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 accelerate==0.21.0
 av==11.0.0
 clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
-decord==0.6.0
+#decord==0.6.0
 diffusers==0.24.0
 einops==0.4.1
 gradio==4.24.0
@@ -10,7 +10,7 @@ imageio==2.33.0
 imageio-ffmpeg==0.4.9
 numpy==1.24.4
 omegaconf==2.2.3
-onnxruntime-gpu==1.16.3
+#onnxruntime-gpu==1.16.3
 open-clip-torch==2.20.0
 opencv-contrib-python==4.8.1.78
 opencv-python==4.8.1.78
@@ -18,15 +18,16 @@ Pillow==9.5.0
 scikit-image==0.21.0
 scikit-learn==1.3.2
 scipy==1.11.4
-torch==2.0.1
+#torch==2.0.1
 torchdiffeq==0.2.3
 torchmetrics==1.2.1
 torchsde==0.2.5
-torchvision==0.15.2
+#torchvision==0.15.2
 tqdm==4.66.1
 transformers==4.30.2
-xformers==0.0.22
+#xformers==0.0.22
 controlnet-aux==0.0.7
-mediapipe==0.10.11
+#mediapipe==0.10.11
+mediapipe==0.10.9
 librosa==0.9.2
 ffmpeg-python==0.2.0
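The requirements changes follow the same theme: decord, onnxruntime-gpu, and xformers are CUDA/x86-centric packages (xformers in particular has no MPS support), and unpinning torch and torchvision presumably leaves it to the launcher to install a build matching the host platform; mediapipe is moved from 0.10.11 down to 0.10.9, likely for platform compatibility. A hypothetical pre-flight check, not part of this commit, that could confirm the environment before launching app.py:

# Hypothetical pre-flight check (assumed helper, not part of this commit).
import importlib.util
import torch

missing = [m for m in ("torchvision", "mediapipe")
           if importlib.util.find_spec(m) is None]
if missing:
    raise SystemExit("install platform-appropriate builds of: " + ", ".join(missing))

backend = ("mps" if torch.backends.mps.is_available()
           else "cuda" if torch.cuda.is_available()
           else "cpu")
print(f"torch {torch.__version__} ready, backend: {backend}")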