import os
from datetime import datetime
from pathlib import Path

import cv2
import ffmpeg
import numpy as np
import spaces
import torch
from omegaconf import OmegaConf
from PIL import Image
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation as R
from torchvision import transforms

from src.utils.util import save_videos_grid
from src.utils.audio_util import prepare_audio_feature
from src.utils.pose_util import project_points
from src.utils.crop_face_single import crop_face
from src.create_modules import lmk_extractor, vis, a2m_model, pipe


def matrix_to_euler_and_translation(matrix):
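    """Decompose a 4x4 homogeneous transform into XYZ Euler angles (degrees) and a translation vector."""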
    rotation_matrix = matrix[:3, :3]
    translation_vector = matrix[:3, 3]
    rotation = R.from_matrix(rotation_matrix)
    euler_angles = rotation.as_euler('xyz', degrees=True)
    return euler_angles, translation_vector


def smooth_pose_seq(pose_seq, window_size=5):
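    """Smooth a pose sequence with a centered moving average over `window_size` frames (truncated at the ends)."""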
    smoothed_pose_seq = np.zeros_like(pose_seq)
    for i in range(len(pose_seq)):
        start = max(0, i - window_size // 2)
        end = min(len(pose_seq), i + window_size // 2 + 1)
        smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
    return smoothed_pose_seq


def get_headpose_temp(input_video):
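    """Extract a head-pose template from a driving video.

    Returns an (N, 6) array of per-frame [euler_xyz, translation] relative to
    the first detected frame, resampled to 30 fps and smoothed.
    """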
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)

    trans_mat_list = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        result = lmk_extractor(frame)
        if result is None:
            # Skip frames with no detected face instead of crashing on the None result.
            continue
        trans_mat_list.append(result['trans_mat'].astype(np.float32))
    cap.release()

    trans_mat_arr = np.array(trans_mat_list)

    # Express every head pose relative to the first detected frame.
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])

    for i in range(pose_arr.shape[0]):
        pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
        euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    # Resample the pose track to 30 fps. Use the number of poses actually collected:
    # CAP_PROP_FRAME_COUNT is only an estimate, and frames without a face were skipped,
    # so the old timeline must match pose_arr's length for interp1d to work.
    num_frames = pose_arr.shape[0]
    new_fps = 30
    old_time = np.linspace(0, num_frames / fps, num_frames)
    new_time = np.linspace(0, num_frames / fps, int(num_frames * new_fps / fps))

    pose_arr_interp = np.zeros((len(new_time), 6))
    for i in range(6):
        interp_func = interp1d(old_time, pose_arr[:, i])
        pose_arr_interp[:, i] = interp_func(new_time)

    pose_arr_smooth = smooth_pose_seq(pose_arr_interp)

    return pose_arr_smooth


@spaces.GPU
def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
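    """Generate a talking-head video from an audio clip and a reference image.

    Head motion is driven by `headpose_video` when given, otherwise by the pose
    template named in the config. Returns (video_path, cropped_ref_image), or
    (None, ref_image) when no face is detected in the reference.
    """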
    fps = 30
    cfg = 3.5

    config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
    audio_infer_config = OmegaConf.load(config.audio_inference_config)

    generator = torch.manual_seed(seed)

    width, height = size, size
    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"

    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)
    # Crop the face from the reference image; bail out if no face is found.
    ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
    ref_image_np = crop_face(ref_image_np, lmk_extractor)
    if ref_image_np is None:
        return None, Image.fromarray(ref_img)

    ref_image_np = cv2.resize(ref_image_np, (size, size))
    ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))

    face_result = lmk_extractor(ref_image_np)
    if face_result is None:
        return None, ref_image_pil

    lmks = face_result['lmks'].astype(np.float32)
    ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
    # Extract wav2vec audio features and run the audio-to-mesh model.
    sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
    sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)

    pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
    pred = pred.squeeze().detach().cpu().numpy()
    pred = pred.reshape(pred.shape[0], -1, 3)
    # The prediction is an offset; add the reference 3D landmarks to get absolute positions.
    pred = pred + face_result['lmks3d']
    # Build the head-pose sequence: from the driving video if given, else from the template.
    if headpose_video is not None:
        pose_seq = get_headpose_temp(headpose_video)
    else:
        pose_seq = np.load(config['pose_temp'])
    # Mirror the sequence back and forth, then tile it to cover the full frame count.
    mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
    cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]

    # Project the 3D landmarks into the image plane under the cycled head poses.
    projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
    pose_images = []
    for verts in projected_vertices:
        lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
        pose_images.append(lmk_img)

    pose_list = []
    pose_tensor_list = []

    pose_transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )
    # Clamp the clip length: 0 (or an over-long request) means "all frames", capped at 300.
    args_L = len(pose_images) if length == 0 or length > len(pose_images) else length
    args_L = min(args_L, 300)
    for pose_image_np in pose_images[:args_L]:
        pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
        pose_tensor_list.append(pose_transform(pose_image_pil))
        pose_image_np = cv2.resize(pose_image_np, (width, height))
        pose_list.append(pose_image_np)

    pose_list = np.array(pose_list)
    video_length = len(pose_tensor_list)

    video = pipe(
        ref_image_pil,
        pose_list,
        ref_pose,
        width,
        height,
        video_length,
        steps,
        cfg,
        generator=generator,
    ).videos

    save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
    save_videos_grid(
        video,
        save_path,
        n_rows=1,
        fps=fps,
    )
    # Mux the original audio into the silent render, then drop the intermediate file.
    stream = ffmpeg.input(save_path)
    audio = ffmpeg.input(input_audio)
    final_path = save_path.replace('_noaudio.mp4', '.mp4')
    ffmpeg.output(stream.video, audio.audio, final_path, vcodec='copy', acodec='aac', shortest=None).run()
    os.remove(save_path)

    return final_path, ref_image_pil
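
# Minimal usage sketch (an assumption, not part of the original app): in the
# hosted Space this function is normally invoked through the UI, and the file
# paths below are hypothetical placeholders.
if __name__ == "__main__":
    ref = np.array(Image.open("examples/ref.jpg").convert("RGB"))  # hypothetical path
    video_path, cropped_ref = audio2video("examples/speech.wav", ref, steps=25, length=150)
    print("saved:", video_path)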