# AniPortrait_official / src /audio2vid.py
# zejunyang
# update
# 8749423
import os
import ffmpeg
from datetime import datetime
from pathlib import Path
import numpy as np
import cv2
# import torch
# import spaces
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import interp1d
# from diffusers import AutoencoderKL, DDIMScheduler
# from einops import repeat
# from omegaconf import OmegaConf
# from PIL import Image
# from torchvision import transforms
# from transformers import CLIPVisionModelWithProjection
# from src.models.pose_guider import PoseGuider
# from src.models.unet_2d_condition import UNet2DConditionModel
# from src.models.unet_3d import UNet3DConditionModel
# from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
# from src.utils.util import save_videos_grid
# from src.audio_models.model import Audio2MeshModel
# from src.utils.audio_util import prepare_audio_feature
from src.utils.mp_utils import LMKExtractor
# from src.utils.draw_util import FaceMeshVisualizer
# from src.utils.pose_util import project_points
# from src.utils.crop_face_single import crop_face
def matrix_to_euler_and_translation(matrix):
    """Split a homogeneous transform into Euler angles and a translation.

    Args:
        matrix: Array-like transform (ndarray or nested sequence) with at
            least 3 rows and 4 columns. The upper-left 3x3 block is read as
            the rotation and the first three entries of column 3 as the
            translation.

    Returns:
        A tuple ``(euler_angles, translation_vector)``: ``euler_angles`` is
        the rotation expressed with scipy's ``'xyz'`` axis sequence, in
        degrees; ``translation_vector`` is the length-3 translation.
    """
    # Coerce first so nested lists can be sliced like an ndarray.
    matrix = np.asarray(matrix)
    rotation_matrix = matrix[:3, :3]
    translation_vector = matrix[:3, 3]
    rotation = R.from_matrix(rotation_matrix)
    euler_angles = rotation.as_euler('xyz', degrees=True)
    return euler_angles, translation_vector
def smooth_pose_seq(pose_seq, window_size=5):
    """Smooth a pose sequence with a centred moving average.

    Output frame ``i`` is the mean of the input frames whose index lies in
    ``[i - window_size // 2, i + window_size // 2]``, clipped to the sequence
    bounds, so the window shrinks near both ends instead of being padded.

    Args:
        pose_seq: Array-like whose first axis indexes frames.
        window_size: Nominal window width in frames.

    Returns:
        A ``numpy.ndarray`` with the same shape as ``pose_seq``. Floating
        (and complex) dtypes are preserved; any other dtype is promoted to
        ``float64`` so the averages are not truncated.
    """
    pose_arr = np.asarray(pose_seq)
    # An integer output buffer would silently truncate every mean.
    if np.issubdtype(pose_arr.dtype, np.inexact):
        out_dtype = pose_arr.dtype
    else:
        out_dtype = np.float64
    smoothed_pose_seq = np.zeros(pose_arr.shape, dtype=out_dtype)

    num_frames = len(pose_arr)
    half_window = window_size // 2
    for i in range(num_frames):
        start = max(0, i - half_window)
        end = min(num_frames, i + half_window + 1)
        smoothed_pose_seq[i] = np.mean(pose_arr[start:end], axis=0)
    return smoothed_pose_seq
def get_headpose_temp(input_video, target_fps=30):
    """Extract a smoothed head-pose template from a video.

    Every decoded frame is passed to ``LMKExtractor``; the ``'trans_mat'`` of
    each detection is expressed relative to the first detected frame, split
    into Euler angles and translation, linearly resampled to ``target_fps``
    and smoothed with ``smooth_pose_seq``.

    Frames for which the extractor returns ``None`` are skipped; their poses
    are filled in by the interpolation (edge gaps hold the nearest detected
    pose).

    Args:
        input_video: Video source handed to ``cv2.VideoCapture``.
        target_fps: Frame rate of the returned sequence.

    Returns:
        ``numpy.ndarray`` of shape ``(num_output_frames, 6)``: columns 0-2
        are the ``'xyz'`` Euler angles in degrees, columns 3-5 the
        translation, both relative to the first detected frame.

    Raises:
        ValueError: If no frame could be read or no frame yielded a result
            from the landmark extractor.
    """
    lmk_extractor = LMKExtractor()
    cap = cv2.VideoCapture(input_video)

    trans_mat_list = []
    detected_frame_ids = []
    num_frames = 0
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            result = lmk_extractor(frame)
            # The extractor yields None when it finds nothing in the frame;
            # keep the frame's time slot but record no pose for it.
            if result is not None:
                trans_mat_list.append(result['trans_mat'].astype(np.float32))
                detected_frame_ids.append(num_frames)
            num_frames += 1
    finally:
        # Release the capture even if the extractor raises mid-video.
        cap.release()

    if not trans_mat_list:
        raise ValueError(
            f"no usable head pose could be extracted from {input_video!r}"
        )

    # Some containers report 0 (or NaN) fps; assume the source already runs
    # at the target rate rather than dividing by zero.
    if not fps > 0:
        fps = float(target_fps)

    trans_mat_arr = np.array(trans_mat_list)

    # compute delta pose
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
    for i, trans_mat in enumerate(trans_mat_arr):
        pose_mat = trans_mat_inv_frame_0 @ trans_mat
        euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    # Resample to target_fps. The time grid is built from the number of
    # frames actually decoded, because CAP_PROP_FRAME_COUNT is only an
    # estimate and a mismatch would break the interpolation.
    duration = num_frames / fps
    old_time = np.linspace(0, duration, num_frames)[detected_frame_ids]
    new_time = np.linspace(0, duration, int(num_frames * target_fps / fps))
    pose_arr_interp = np.zeros((len(new_time), 6))
    for col in range(6):
        # np.interp is linear like interp1d's default, but also accepts a
        # single sample and clamps outside the detected range.
        pose_arr_interp[:, col] = np.interp(new_time, old_time, pose_arr[:, col])

    pose_arr_smooth = smooth_pose_seq(pose_arr_interp)
    return pose_arr_smooth
# @spaces.GPU(duration=150)
# def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
# fps = 30
# cfg = 3.5
# config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
# if config.weight_dtype == "fp16":
# weight_dtype = torch.float16
# else:
# weight_dtype = torch.float32
# audio_infer_config = OmegaConf.load(config.audio_inference_config)
# # prepare model
# a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
# a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
# a2m_model.cuda().eval()
# vae = AutoencoderKL.from_pretrained(
# config.pretrained_vae_path,
# ).to("cuda", dtype=weight_dtype)
# reference_unet = UNet2DConditionModel.from_pretrained(
# config.pretrained_base_model_path,
# subfolder="unet",
# ).to(dtype=weight_dtype, device="cuda")
# inference_config_path = config.inference_config
# infer_config = OmegaConf.load(inference_config_path)
# denoising_unet = UNet3DConditionModel.from_pretrained_2d(
# config.pretrained_base_model_path,
# config.motion_module_path,
# subfolder="unet",
# unet_additional_kwargs=infer_config.unet_additional_kwargs,
# ).to(dtype=weight_dtype, device="cuda")
# pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
# image_enc = CLIPVisionModelWithProjection.from_pretrained(
# config.image_encoder_path
# ).to(dtype=weight_dtype, device="cuda")
# sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
# scheduler = DDIMScheduler(**sched_kwargs)
# generator = torch.manual_seed(seed)
# width, height = size, size
# # load pretrained weights
# denoising_unet.load_state_dict(
# torch.load(config.denoising_unet_path, map_location="cpu"),
# strict=False,
# )
# reference_unet.load_state_dict(
# torch.load(config.reference_unet_path, map_location="cpu"),
# )
# pose_guider.load_state_dict(
# torch.load(config.pose_guider_path, map_location="cpu"),
# )
# pipe = Pose2VideoPipeline(
# vae=vae,
# image_encoder=image_enc,
# reference_unet=reference_unet,
# denoising_unet=denoising_unet,
# pose_guider=pose_guider,
# scheduler=scheduler,
# )
# pipe = pipe.to("cuda", dtype=weight_dtype)
# date_str = datetime.now().strftime("%Y%m%d")
# time_str = datetime.now().strftime("%H%M")
# save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
# save_dir = Path(f"output/{date_str}/{save_dir_name}")
# save_dir.mkdir(exist_ok=True, parents=True)
# lmk_extractor = LMKExtractor()
# vis = FaceMeshVisualizer(forehead_edge=False)
# ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
# ref_image_np = crop_face(ref_image_np, lmk_extractor)
# if ref_image_np is None:
# return None, Image.fromarray(ref_img)
# ref_image_np = cv2.resize(ref_image_np, (size, size))
# ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
# face_result = lmk_extractor(ref_image_np)
# if face_result is None:
# return None, ref_image_pil
# lmks = face_result['lmks'].astype(np.float32)
# ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
# sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
# sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
# sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
# # inference
# pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
# pred = pred.squeeze().detach().cpu().numpy()
# pred = pred.reshape(pred.shape[0], -1, 3)
# pred = pred + face_result['lmks3d']
# if headpose_video is not None:
# pose_seq = get_headpose_temp(headpose_video)
# else:
# pose_seq = np.load(config['pose_temp'])
# mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
# cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
# # project 3D mesh to 2D landmark
# projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
# pose_images = []
# for i, verts in enumerate(projected_vertices):
# lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
# pose_images.append(lmk_img)
# pose_list = []
# pose_tensor_list = []
# pose_transform = transforms.Compose(
# [transforms.Resize((height, width)), transforms.ToTensor()]
# )
# args_L = len(pose_images) if length==0 or length > len(pose_images) else length
# args_L = min(args_L, 300)
# for pose_image_np in pose_images[: args_L]:
# pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
# pose_tensor_list.append(pose_transform(pose_image_pil))
# pose_image_np = cv2.resize(pose_image_np, (width, height))
# pose_list.append(pose_image_np)
# pose_list = np.array(pose_list)
# video_length = len(pose_tensor_list)
# video = pipe(
# ref_image_pil,
# pose_list,
# ref_pose,
# width,
# height,
# video_length,
# steps,
# cfg,
# generator=generator,
# ).videos
# save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
# save_videos_grid(
# video,
# save_path,
# n_rows=1,
# fps=fps,
# )
# stream = ffmpeg.input(save_path)
# audio = ffmpeg.input(input_audio)
# ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
# os.remove(save_path)
# return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil