import importlib
import math

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from einops import rearrange

from data_utils import create_masks_from_landmarks_box


def save_audio_video(
    video,
    audio=None,
    frame_rate=25,
    sample_rate=16000,
    save_path="temp.mp4",
):
    """Save video (and optionally audio) to a single file.

    video: (t, c, h, w)
    audio: (channels, samples)
    """
    save_path = str(save_path)
    if isinstance(video, torch.Tensor):
        video = video.cpu().numpy()
    video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
    print("video_tensor shape", video_tensor.shape)
    if audio is not None:
        print("audio shape", audio.shape)
        # Audio is expected as a tensor of shape (channels, samples)
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            audio_array=audio,
            audio_fps=sample_rate,
            video_codec="h264",  # Explicit codecs so the container is widely playable
            audio_codec="aac",
        )
    else:
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            video_codec="h264",
        )
    return save_path


def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    """Trim or zero-pad audio along the last dimension to the target length."""
    len_file = audio.shape[-1]
    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav


def get_raw_audio(audio_path, audio_rate, fps=25):
    """Load audio, downmix to mono, resample, and reshape into per-video-frame chunks."""
    audio, sr = torchaudio.load(audio_path, channels_first=True)
    if audio.shape[0] > 1:
        audio = audio.mean(0, keepdim=True)
    audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
    samples_per_frame = math.ceil(audio_rate / fps)
    n_frames = audio.shape[-1] / samples_per_frame
    if not n_frames.is_integer():
        # Pad so the waveform length is an exact multiple of samples_per_frame
        audio = trim_pad_audio(
            audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
        )
    audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
    return audio


def calculate_splits(tensor, min_last_size):
    """Split a tensor along dim=1 so the last chunk has at least min_last_size elements."""
    total_size = tensor.size(1)  # size along the second dimension

    # If the tensor is already small enough, return it as a single split
    if total_size <= min_last_size:
        return [tensor]

    # Calculate the number of splits and the base size of each split
    num_splits = (total_size - min_last_size) // min_last_size + 1
    base_size = (total_size - min_last_size) // num_splits

    # Build the split sizes; the last entry absorbs the remainder so the sizes
    # sum to total_size and the last split has at least min_last_size elements
    split_sizes = [base_size] * (num_splits - 1)
    split_sizes.append(total_size - sum(split_sizes))

    # Safety net: adjust sizes if they do not sum exactly to total_size
    sum_sizes = sum(split_sizes)
    while sum_sizes != total_size:
        for i in range(num_splits):
            if sum_sizes < total_size:
                split_sizes[i] += 1
                sum_sizes += 1
            if sum_sizes >= total_size:
                break

    return torch.split(tensor, split_sizes, dim=1)


def make_into_multiple_of(x, multiple, dim=0):
    """Zero-pad the tensor along `dim` so its size becomes a multiple of `multiple`."""
    if x.shape[dim] % multiple != 0:
        x = torch.cat(
            [
                x,
                torch.zeros(
                    *x.shape[:dim],
                    multiple - (x.shape[dim] % multiple),
                    *x.shape[dim + 1 :],
                ).to(x.device),
            ],
            dim=dim,
        )
    return x


def default(value, default_value):
    """Return `default_value` if `value` is None, otherwise `value`."""
    return default_value if value is None else value
def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as "package.module.ClassName" to the object it names."""
    module, cls = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def load_landmarks(
    landmarks: np.ndarray,
    original_size,
    target_size=(64, 64),
    nose_index=28,
):
    """
    Process facial landmarks into a box mask resized to the target resolution.

    Args:
        landmarks: Facial landmarks array
        original_size: Original size of the video frames
        target_size: Target size for the output mask
        nose_index: Index of the nose landmark

    Returns:
        Processed landmarks mask
    """
    expand_box = 0.0
    if len(landmarks.shape) == 2:
        landmarks = landmarks[None, ...]
    mask = create_masks_from_landmarks_box(
        landmarks,
        (original_size[0], original_size[1]),
        box_expand=expand_box,
        nose_index=nose_index,
    )
    mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
    return mask


def create_pipeline_inputs(
    audio: torch.Tensor,
    audio_interpolation: torch.Tensor,
    num_frames: int,
    video_emb: torch.Tensor,
    landmarks: np.ndarray,
    overlap: int = 1,
    add_zero_flag: bool = False,
    mask_arms: np.ndarray = None,
    nose_index: int = 28,
):
    """
    Create inputs for the keyframe generation and interpolation pipeline.

    Args:
        audio: Audio embeddings for keyframe generation
        audio_interpolation: Audio embeddings for interpolation
        num_frames: Number of frames per segment
        video_emb: Video embeddings
        landmarks: Facial landmarks for mask generation
        overlap: Number of frames to overlap between segments
        add_zero_flag: Whether to re-insert the first element every num_frames
        mask_arms: Optional mask for the arms region
        nose_index: Index of the nose landmark point

    Returns:
        Tuple containing all necessary inputs for the pipeline
    """
    audio_interpolation_chunks = []
    audio_image_preds = []
    gt_chunks = []
    gt_keyframes_chunks = []

    # Adjust for overlap so segments advance by at least one frame per iteration
    step = max(num_frames - overlap, 1)

    audio_image_preds_idx = []
    audio_interp_preds_idx = []
    masks_chunks = []
    masks_interpolation_chunks = []
    for i in range(0, audio.shape[0] - num_frames + 1, step):
        try:
            audio[i + num_frames - 1]
        except IndexError:
            break  # Last chunk is smaller than num_frames
        segment_end = i + num_frames
        gt_chunks.append(video_emb[i:segment_end])
        masks = load_landmarks(
            landmarks[i:segment_end],
            (512, 512),
            target_size=(64, 64),
            nose_index=nose_index,
        )
        if mask_arms is not None:
            masks = np.logical_and(
                masks, np.logical_not(mask_arms[i:segment_end, None, ...])
            )
        masks_interpolation_chunks.append(masks)

        # Keyframe conditioning at the start of the segment
        if i not in audio_image_preds_idx:
            audio_image_preds.append(audio[i])
            masks_chunks.append(masks[0])
            gt_keyframes_chunks.append(video_emb[i])
            audio_image_preds_idx.append(i)

        # Keyframe conditioning at the end of the segment
        if segment_end - 1 not in audio_image_preds_idx:
            audio_image_preds_idx.append(segment_end - 1)
            audio_image_preds.append(audio[segment_end - 1])
            masks_chunks.append(masks[-1])
            gt_keyframes_chunks.append(video_emb[segment_end - 1])

        audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
        audio_interp_preds_idx.append([i, segment_end - 1])

    # If the flag is on, re-insert the first element every num_frames keyframes
    if add_zero_flag:
        first_element = audio_image_preds[0]
        len_audio_image_preds = (
            len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
        )
        for i in range(0, len_audio_image_preds, num_frames):
            audio_image_preds.insert(i, first_element)
            audio_image_preds_idx.insert(i, None)
            masks_chunks.insert(i, masks_chunks[0])
            gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])

    to_remove = [idx is None for idx in audio_image_preds_idx]
    audio_image_preds_idx_clone = [idx for idx in audio_image_preds_idx]
    if add_zero_flag:
        # Drop the placeholder indices that were inserted above
        audio_image_preds_idx = [
            sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
        ]

    interpolation_cond_list = []
    for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
        interpolation_cond_list.append(
            [audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
        )

    # Since we generate num_frames at a time, pad the keyframe lists so their
    # length is a multiple of num_frames by repeating the last element
    frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames

    audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
    masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
    gt_keyframes_chunks = (
        gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
    )

    to_remove = to_remove + [True] * frames_needed
    audio_image_preds_idx_clone = (
        audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
    )

    print(
        f"Added {frames_needed} repeated frames to make audio_image_preds a multiple of {num_frames}"
    )

    # random_cond_idx = np.random.randint(0, len(video_emb))
    random_cond_idx = 0

    assert len(to_remove) == len(audio_image_preds), (
        "to_remove and audio_image_preds must have the same length"
    )

    return (
        gt_chunks,
        gt_keyframes_chunks,
        audio_interpolation_chunks,
        audio_image_preds,
        video_emb[random_cond_idx],
        masks_chunks,
        masks_interpolation_chunks,
        to_remove,
        audio_interp_preds_idx,
        audio_image_preds_idx_clone,
    )
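

if __name__ == "__main__":
    # Illustrative sketch only: a quick, self-contained check of the generic tensor
    # helpers above using synthetic data. It is not part of the original pipeline,
    # and the sizes (103 frames, multiples of 14, chunks of at least 25) are
    # arbitrary values chosen just for this example.
    dummy = torch.zeros(1, 103)

    # Pad along dim=1 so the length becomes a multiple of 14 -> torch.Size([1, 112])
    padded = make_into_multiple_of(dummy, 14, dim=1)
    print("padded:", padded.shape)

    # Split along dim=1 so every chunk, including the last, has at least 25 elements
    for chunk in calculate_splits(padded, 25):
        print("chunk:", chunk.shape)

    # `default` falls back to the second argument only when the first is None
    print(default(None, 16000), default(8000, 16000))  # 16000 8000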