import argparse
import logging
import os

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from packaging import version
from tqdm import tqdm

from memo.models.audio_proj import AudioProjModel
from memo.models.image_proj import ImageProjModel
from memo.models.unet_2d_condition import UNet2DConditionModel
from memo.models.unet_3d import UNet3DConditionModel
from memo.pipelines.video_pipeline import VideoPipeline
from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio
from memo.utils.vision_utils import preprocess_image, tensor_to_video


logger = logging.getLogger("memo")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(description="Inference script for MEMO")
    parser.add_argument("--config", type=str, default="configs/inference.yaml")
    parser.add_argument("--input_image", type=str)
    parser.add_argument("--input_audio", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
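
# Example invocation (the script name and all paths below are illustrative placeholders;
# adjust them to your checkout):
#   python inference.py --config configs/inference.yaml \
#       --input_image <reference_image>.jpg \
#       --input_audio <speech>.wav \
#       --output_dir outputs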


def main():
    # Parse arguments
    args = parse_args()
    input_image_path = args.input_image
    input_audio_path = args.input_audio
    if "wav" not in input_audio_path:
        logger.warning("MEMO might not generate full-length video for non-wav audio file.")
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_video_path = os.path.join(
        output_dir,
        f"{os.path.basename(input_image_path).split('.')[0]}_{os.path.basename(input_audio_path).split('.')[0]}.mp4",
    )
    if os.path.exists(output_video_path):
        logger.info(f"Output file {output_video_path} already exists. Skipping inference.")
        return

    generator = torch.manual_seed(args.seed)

    logger.info(f"Loading config from {args.config}")
    config = OmegaConf.load(args.config)
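    # Keys this script reads from the config: model_name_or_path, misc_model_dir, weight_dtype,
    # resolution, fps, num_generated_frames_per_clip, num_init_past_frames, num_past_frames,
    # wav2vec, emotion2vec, vae, enable_xformers_memory_efficient_attention, inference_steps,
    # and cfg_scale.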

    # Determine model paths
    if config.model_name_or_path == "memoavatar/memo":
        logger.info(
            f"The MEMO model will be downloaded from Hugging Face to the default cache directory. The models for face analysis and vocal separation will be downloaded to {config.misc_model_dir}."
        )

        face_analysis = os.path.join(config.misc_model_dir, "misc/face_analysis")
        os.makedirs(face_analysis, exist_ok=True)
        for model in [
            "1k3d68.onnx",
            "2d106det.onnx",
            "face_landmarker_v2_with_blendshapes.task",
            "genderage.onnx",
            "glintr100.onnx",
            "scrfd_10g_bnkps.onnx",
        ]:
            if not os.path.exists(os.path.join(face_analysis, model)):
                logger.info(f"Downloading {model} to {face_analysis}")
                os.system(
                    f"wget -P {face_analysis} https://huggingface.co/memoavatar/memo/raw/main/misc/face_analysis/models/{model}"
                )
        logger.info(f"Using face analysis models from {face_analysis}")

        vocal_separator = os.path.join(config.misc_model_dir, "misc/vocal_separator/Kim_Vocal_2.onnx")
        if os.path.exists(vocal_separator):
            logger.info(f"Vocal separator {vocal_separator} already exists. Skipping download.")
        else:
            logger.info(f"Downloading vocal separator to {vocal_separator}")
            os.makedirs(os.path.dirname(vocal_separator), exist_ok=True)
            os.system(
                f"wget -P {os.path.dirname(vocal_separator)} https://huggingface.co/memoavatar/memo/raw/main/misc/vocal_separator/Kim_Vocal_2.onnx"
            )
    else:
        logger.info(f"Loading manually specified model path: {config.model_name_or_path}")
        face_analysis = os.path.join(config.model_name_or_path, "misc/face_analysis")
        vocal_separator = os.path.join(config.model_name_or_path, "misc/vocal_separator/Kim_Vocal_2.onnx")

    # Set up device and weight dtype
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        weight_dtype = torch.float32
    logger.info(f"Inference dtype: {weight_dtype}")

    logger.info(f"Processing image {input_image_path}")
    img_size = (config.resolution, config.resolution)
    pixel_values, face_emb = preprocess_image(
        face_analysis_model=face_analysis,
        image_path=input_image_path,
        image_size=config.resolution,
    )
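    # pixel_values is the preprocessed reference-image tensor and face_emb the face identity
    # embedding from the face analysis models; both condition the video pipeline below.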
| logger.info(f"Processing audio {input_audio_path}") | |
| cache_dir = os.path.join(output_dir, "audio_preprocess") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| input_audio_path = resample_audio( | |
| input_audio_path, | |
| os.path.join(cache_dir, f"{os.path.basename(input_audio_path).split('.')[0]}-16k.wav"), | |
| ) | |
| audio_emb, audio_length = preprocess_audio( | |
| wav_path=input_audio_path, | |
| num_generated_frames_per_clip=config.num_generated_frames_per_clip, | |
| fps=config.fps, | |
| wav2vec_model=config.wav2vec, | |
| vocal_separator_model=vocal_separator, | |
| cache_dir=cache_dir, | |
| device=device, | |
| ) | |
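    # audio_emb holds one audio embedding per output video frame and audio_length the
    # corresponding frame count; they drive the clip loop and the final trim below.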
| logger.info("Processing audio emotion") | |
| audio_emotion, num_emotion_classes = extract_audio_emotion_labels( | |
| model=config.model_name_or_path, | |
| wav_path=input_audio_path, | |
| emotion2vec_model=config.emotion2vec, | |
| audio_length=audio_length, | |
| device=device, | |
| ) | |
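    # audio_emotion gives per-frame emotion labels from emotion2vec; num_emotion_classes is
    # forwarded to the pipeline as emotion_class_num.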
| logger.info("Loading models") | |
| vae = AutoencoderKL.from_pretrained(config.vae).to(device=device, dtype=weight_dtype) | |
| reference_net = UNet2DConditionModel.from_pretrained( | |
| config.model_name_or_path, subfolder="reference_net", use_safetensors=True | |
| ) | |
| diffusion_net = UNet3DConditionModel.from_pretrained( | |
| config.model_name_or_path, subfolder="diffusion_net", use_safetensors=True | |
| ) | |
| image_proj = ImageProjModel.from_pretrained( | |
| config.model_name_or_path, subfolder="image_proj", use_safetensors=True | |
| ) | |
| audio_proj = AudioProjModel.from_pretrained( | |
| config.model_name_or_path, subfolder="audio_proj", use_safetensors=True | |
| ) | |
| vae.requires_grad_(False).eval() | |
| reference_net.requires_grad_(False).eval() | |
| diffusion_net.requires_grad_(False).eval() | |
| image_proj.requires_grad_(False).eval() | |
| audio_proj.requires_grad_(False).eval() | |

    # Enable memory-efficient attention for xFormers
    if config.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.info(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            reference_net.enable_xformers_memory_efficient_attention()
            diffusion_net.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Create inference pipeline
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    pipeline = VideoPipeline(
        vae=vae,
        reference_net=reference_net,
        diffusion_net=diffusion_net,
        scheduler=noise_scheduler,
        image_proj=image_proj,
    )
    pipeline.to(device=device, dtype=weight_dtype)
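
    # Generate the video clip by clip. Each clip is conditioned on the reference image plus a
    # window of past frames: copies of the reference image for the first clip, then the last
    # num_past_frames frames of the previous clip. Note that audio_proj is applied here, per
    # clip, rather than inside the pipeline.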
    video_frames = []
    num_clips = audio_emb.shape[0] // config.num_generated_frames_per_clip
    for t in tqdm(range(num_clips), desc="Generating video clips"):
        if len(video_frames) == 0:
            # Initialize the first past frames with the reference image
            past_frames = pixel_values.repeat(config.num_init_past_frames, 1, 1, 1)
            past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
        else:
            # Reuse the last num_past_frames frames of the previous clip, rescaled from
            # [0, 1] to [-1, 1] to match pixel_values
            past_frames = video_frames[-1][0]
            past_frames = past_frames.permute(1, 0, 2, 3)
            past_frames = past_frames[0 - config.num_past_frames :]
            past_frames = past_frames * 2.0 - 1.0
            past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        # Slice the audio embeddings and emotion labels belonging to this clip
        audio_tensor = (
            audio_emb[
                t * config.num_generated_frames_per_clip : min(
                    (t + 1) * config.num_generated_frames_per_clip, audio_emb.shape[0]
                )
            ]
            .unsqueeze(0)
            .to(device=audio_proj.device, dtype=audio_proj.dtype)
        )
        audio_tensor = audio_proj(audio_tensor)

        audio_emotion_tensor = audio_emotion[
            t * config.num_generated_frames_per_clip : min(
                (t + 1) * config.num_generated_frames_per_clip, audio_emb.shape[0]
            )
        ]

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            audio_emotion=audio_emotion_tensor,
            emotion_class_num=num_emotion_classes,
            face_emb=face_emb,
            width=img_size[0],
            height=img_size[1],
            video_length=config.num_generated_frames_per_clip,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
        )

        video_frames.append(pipeline_output.videos)

    # Concatenate the clips along the frame dimension, trim to the audio length, and write
    # the mp4 together with the (resampled) input audio
    video_frames = torch.cat(video_frames, dim=2)
    video_frames = video_frames.squeeze(0)
    video_frames = video_frames[:, :audio_length]
    tensor_to_video(video_frames, output_video_path, input_audio_path, fps=config.fps)


if __name__ == "__main__":
    main()