import os

import cv2
import numpy as np
import torch
import torchaudio.functional
import torchvision.io
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection, Wav2Vec2Model, Wav2Vec2Processor

from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection
from pipelines import VExpressPipeline
from pipelines.utils import draw_kps_image, save_video
from pipelines.utils import retarget_kps


def load_reference_net(unet_config_path, reference_net_path, dtype, device):
    reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
    reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Reference Net from {reference_net_path}.')
    return reference_net


def load_denoising_unet(unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
    inference_config_path = './inference_v2.yaml'
    inference_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_config_2d(
        unet_config_path,
        unet_additional_kwargs=inference_config.unet_additional_kwargs,
    ).to(dtype=dtype, device=device)
    denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')
    denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')
    return denoising_unet


def load_v_kps_guider(v_kps_guider_path, dtype, device):
    v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
    v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
    print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
    return v_kps_guider


def load_audio_projection(
        audio_projection_path,
        dtype,
        device,
        inp_dim: int,
        mid_dim: int,
        out_dim: int,
        inp_seq_len: int,
        out_seq_len: int,
):
    audio_projection = AudioProjection(
        dim=mid_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=out_seq_len,
        embedding_dim=inp_dim,
        output_dim=out_dim,
        ff_mult=4,
        max_seq_len=inp_seq_len,
    ).to(dtype=dtype, device=device)
    audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu'))
    print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
    return audio_projection


def get_scheduler():
    inference_config_path = './inference_v2.yaml'
    inference_config = OmegaConf.load(inference_config_path)
    scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**scheduler_kwargs)
    return scheduler
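
# ---------------------------------------------------------------------------
# Illustrative helper (not called by this module). InferenceEngine.get_kps_sequence
# below only *loads* a keypoint file of shape [num_frames, 3, 2]; this is a minimal
# sketch of one way such a file could be produced from a driving video with
# insightface. The function name and its loop are assumptions for illustration,
# not part of the original pipeline.
# ---------------------------------------------------------------------------
def extract_kps_sequence_example(video_path, kps_save_path, face_analysis_app):
    """Save per-frame 3-point face keypoints ([num_frames, 3, 2]) from a video."""
    capture = cv2.VideoCapture(video_path)
    kps_sequence = []
    while True:
        ret, frame = capture.read()
        if not ret:
            break
        faces = face_analysis_app.get(frame)
        if len(faces) == 0:
            continue  # skip frames without a detected face
        # The same three keypoints that get_reference_image_for_kps() uses below.
        kps_sequence.append(faces[0].kps[:3])
    capture.release()
    torch.save(torch.from_numpy(np.stack(kps_sequence)), kps_save_path)
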
class InferenceEngine(object):
    def __init__(self, args):
        self.init_params(args)
        self.load_models()
        self.set_generator()
        self.set_vexpress_pipeline()
        self.set_face_analysis_app()

    def init_params(self, args):
        for key, value in args.items():
            setattr(self, key, value)
        print("Image width: ", self.image_width)
        print("Image height: ", self.image_height)

    def load_models(self):
        self.device = torch.device(f'cuda:{self.gpu_id}')
        self.dtype = torch.float16 if self.dtype == 'fp16' else torch.float32

        self.vae = AutoencoderKL.from_pretrained(self.vae_path).to(dtype=self.dtype, device=self.device)
        print("VAE exists: ", self.vae)
        self.audio_encoder = Wav2Vec2Model.from_pretrained(self.audio_encoder_path).to(dtype=self.dtype, device=self.device)
        self.audio_processor = Wav2Vec2Processor.from_pretrained(self.audio_encoder_path)

        self.scheduler = get_scheduler()
        self.reference_net = load_reference_net(self.unet_config_path, self.reference_net_path, self.dtype, self.device)
        self.denoising_unet = load_denoising_unet(self.unet_config_path, self.denoising_unet_path, self.motion_module_path, self.dtype, self.device)
        self.v_kps_guider = load_v_kps_guider(self.v_kps_guider_path, self.dtype, self.device)
        self.audio_projection = load_audio_projection(
            self.audio_projection_path,
            self.dtype,
            self.device,
            inp_dim=self.denoising_unet.config.cross_attention_dim,
            mid_dim=self.denoising_unet.config.cross_attention_dim,
            out_dim=self.denoising_unet.config.cross_attention_dim,
            inp_seq_len=2 * (2 * self.num_pad_audio_frames + 1),
            out_seq_len=2 * self.num_pad_audio_frames + 1,
        )

        if is_xformers_available():
            self.reference_net.enable_xformers_memory_efficient_attention()
            self.denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    def set_generator(self):
        self.generator = torch.manual_seed(self.seed)

    def set_vexpress_pipeline(self):
        print("VAE exists (2): ", self.vae)
        self.pipeline = VExpressPipeline(
            vae=self.vae,
            reference_net=self.reference_net,
            denoising_unet=self.denoising_unet,
            v_kps_guider=self.v_kps_guider,
            audio_processor=self.audio_processor,
            audio_encoder=self.audio_encoder,
            audio_projection=self.audio_projection,
            scheduler=self.scheduler,
        ).to(dtype=self.dtype, device=self.device)

    def set_face_analysis_app(self):
        self.app = FaceAnalysis(
            providers=['CUDAExecutionProvider'],
            provider_options=[{'device_id': self.gpu_id}],
            root=self.insightface_model_path,
        )
        self.app.prepare(ctx_id=0, det_size=(self.image_height, self.image_width))

    def get_reference_image_for_kps(self, reference_image_path):
        reference_image = Image.open(reference_image_path).convert('RGB')
        print("Image width ???", self.image_width)
        reference_image = reference_image.resize((self.image_height, self.image_width))

        reference_image_for_kps = cv2.imread(reference_image_path)
        reference_image_for_kps = cv2.resize(reference_image_for_kps, (self.image_height, self.image_width))
        reference_kps = self.app.get(reference_image_for_kps)[0].kps[:3]
        return reference_image, reference_image_for_kps, reference_kps

    def get_waveform_video_length(self, audio_path):
        _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec')
        audio_sampling_rate = meta_info['audio_fps']
        print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
        if audio_sampling_rate != self.standard_audio_sampling_rate:
            audio_waveform = torchaudio.functional.resample(
                audio_waveform,
                orig_freq=audio_sampling_rate,
                new_freq=self.standard_audio_sampling_rate,
            )
        audio_waveform = audio_waveform.mean(dim=0)

        duration = audio_waveform.shape[0] / self.standard_audio_sampling_rate
        video_length = int(duration * self.fps)
        print(f'The corresponding video length is {video_length}.')
        return audio_waveform, video_length

    def get_kps_sequence(self, kps_path, reference_kps, video_length, retarget_strategy):
        if kps_path != "":
            assert os.path.exists(kps_path), f'{kps_path} does not exist'
            kps_sequence = torch.tensor(torch.load(kps_path))  # [len, 3, 2]
            print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
            kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
            kps_sequence = kps_sequence.permute(2, 0, 1)
            print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')

        if retarget_strategy == 'fix_face':
            kps_sequence = torch.tensor([reference_kps] * video_length)
        elif retarget_strategy == 'no_retarget':
            kps_sequence = kps_sequence
        elif retarget_strategy == 'offset_retarget':
            kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
        elif retarget_strategy == 'naive_retarget':
            kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
        else:
            raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')
        return kps_sequence

    def get_kps_images(self, kps_sequence, reference_image_for_kps, video_length):
        kps_images = []
        for i in range(video_length):
            kps_image = np.zeros_like(reference_image_for_kps)
            kps_image = draw_kps_image(kps_image, kps_sequence[i])
            kps_images.append(Image.fromarray(kps_image))
        return kps_images

    def get_video_latents(self, reference_image, kps_images, audio_waveform, video_length,
                          reference_attention_weight, audio_attention_weight):
        vae_scale_factor = 8
        latent_height = self.image_height // vae_scale_factor
        latent_width = self.image_width // vae_scale_factor

        latent_shape = (1, 4, video_length, latent_height, latent_width)
        vae_latents = randn_tensor(latent_shape, generator=self.generator, device=self.device, dtype=self.dtype)

        video_latents = self.pipeline(
            vae_latents=vae_latents,
            reference_image=reference_image,
            kps_images=kps_images,
            audio_waveform=audio_waveform,
            width=self.image_width,
            height=self.image_height,
            video_length=video_length,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=self.guidance_scale,
            context_frames=self.context_frames,
            context_stride=self.context_stride,
            context_overlap=self.context_overlap,
            reference_attention_weight=reference_attention_weight,
            audio_attention_weight=audio_attention_weight,
            num_pad_audio_frames=self.num_pad_audio_frames,
            generator=self.generator,
        ).video_latents
        return video_latents

    def get_video_tensor(self, video_latents):
        video_tensor = self.pipeline.decode_latents(video_latents)
        if isinstance(video_tensor, np.ndarray):
            video_tensor = torch.from_numpy(video_tensor)
        return video_tensor

    def save_video_tensor(self, video_tensor, audio_path, output_path):
        save_video(video_tensor, audio_path, output_path, self.fps)
        print(f'The generated video has been saved at {output_path}.')

    def infer(
            self,
            reference_image_path,
            audio_path,
            kps_path,
            output_path,
            retarget_strategy,
            reference_attention_weight,
            audio_attention_weight,
    ):
        reference_image, reference_image_for_kps, reference_kps = self.get_reference_image_for_kps(reference_image_path)
        audio_waveform, video_length = self.get_waveform_video_length(audio_path)
        kps_sequence = self.get_kps_sequence(kps_path, reference_kps, video_length, retarget_strategy)
        kps_images = self.get_kps_images(kps_sequence, reference_image_for_kps, video_length)
        video_latents = self.get_video_latents(
            reference_image, kps_images, audio_waveform, video_length,
            reference_attention_weight, audio_attention_weight)
        video_tensor = self.get_video_tensor(video_latents)
        self.save_video_tensor(video_tensor, audio_path, output_path)
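

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The argument keys are
# inferred from the attribute accesses above; every path and numeric value below
# is a placeholder and must be adapted to the local checkpoint layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_args = {
        'gpu_id': 0,
        'dtype': 'fp16',
        'seed': 42,
        'fps': 30,
        'image_width': 512,
        'image_height': 512,
        'standard_audio_sampling_rate': 16000,   # wav2vec2 models expect 16 kHz audio
        'num_pad_audio_frames': 2,
        'num_inference_steps': 25,
        'guidance_scale': 3.5,
        'context_frames': 12,
        'context_stride': 1,
        'context_overlap': 4,
        # Placeholder checkpoint locations:
        'vae_path': './model_ckpts/sd-vae-ft-mse',
        'audio_encoder_path': './model_ckpts/wav2vec2-base-960h',
        'unet_config_path': './model_ckpts/stable-diffusion-v1-5/unet/config.json',
        'reference_net_path': './model_ckpts/v-express/reference_net.pth',
        'denoising_unet_path': './model_ckpts/v-express/denoising_unet.pth',
        'motion_module_path': './model_ckpts/v-express/motion_module.pth',
        'v_kps_guider_path': './model_ckpts/v-express/v_kps_guider.pth',
        'audio_projection_path': './model_ckpts/v-express/audio_projection.pth',
        'insightface_model_path': './model_ckpts/insightface_models',
    }
    engine = InferenceEngine(example_args)
    engine.infer(
        reference_image_path='./test_samples/ref.jpg',
        audio_path='./test_samples/aud.mp4',   # audio track is read via torchvision.io.read_video
        kps_path='./test_samples/kps.pth',     # or '' when using the fix_face strategy
        output_path='./output.mp4',
        retarget_strategy='fix_face',
        reference_attention_weight=0.95,
        audio_attention_weight=3.0,
    )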