import random
import numpy as np
from functools import partial
from torch.utils.data import Dataset, WeightedRandomSampler
import torch.nn.functional as F
import torch
import math
import decord
from einops import rearrange
from more_itertools import sliding_window
from omegaconf import ListConfig
import torchaudio
import soundfile as sf
from torchvision.transforms import RandomHorizontalFlip
from audiomentations import Compose, AddGaussianNoise, PitchShift
from safetensors.torch import load_file
from tqdm import tqdm
import cv2
from sgm.data.data_utils import (
    create_masks_from_landmarks_full_size,
    create_face_mask_from_landmarks,
    create_masks_from_landmarks_box,
    create_masks_from_landmarks_mouth,
)
from sgm.data.mask import face_mask_cheeks_batch

torchaudio.set_audio_backend("sox_io")
decord.bridge.set_bridge("torch")


def exists(x):
    return x is not None


def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    len_file = audio.shape[-1]

    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
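
# Illustrative sketch (not part of the pipeline): trim_pad_audio zero-pads or
# truncates the last dimension to max_len_raw samples, or to max_len_sec * sr
# samples when only a duration is given. For a hypothetical 16 kHz mono track:
#
#   wav = torch.randn(1, 12000)                         # 0.75 s of audio
#   trim_pad_audio(wav, 16000, max_len_sec=1.0).shape   # -> (1, 16000), zero-padded
#   trim_pad_audio(wav, 16000, max_len_raw=8000).shape  # -> (1, 8000), truncated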


# Similar to the regular video dataset, but trades flexibility for speed
class VideoDataset(Dataset):
    def __init__(
        self,
        filelist,
        resize_size=None,
        audio_folder="Audio",
        video_folder="CroppedVideos",
        emotions_folder="emotions",
        landmarks_folder=None,
        audio_emb_folder=None,
        video_extension=".avi",
        audio_extension=".wav",
        audio_rate=16000,
        latent_folder=None,
        audio_in_video=False,
        fps=25,
        num_frames=5,
        need_cond=True,
        step=1,
        mode="prediction",
        scale_audio=False,
        augment=False,
        augment_audio=False,
        use_latent=False,
        latent_type="stable",
        latent_scale=1,  # For backwards compatibility
        from_audio_embedding=False,
        load_all_possible_indexes=False,
        audio_emb_type="wavlm",
        cond_noise=[-3.0, 0.5],
        motion_id=255.0,
        data_mean=None,
        data_std=None,
        use_latent_condition=False,
        skip_frames=0,
        get_separate_id=False,
        virtual_increase=1,
        filter_by_length=False,
        select_randomly=False,
        balance_datasets=True,
        use_emotions=False,
        get_original_frames=False,
        add_extra_audio_emb=False,
        expand_box=0.0,
        nose_index=28,
        what_mask="full",
        get_masks=False,
    ):
        self.audio_folder = audio_folder
        self.from_audio_embedding = from_audio_embedding
        self.audio_emb_type = audio_emb_type
        self.cond_noise = cond_noise
        self.latent_condition = use_latent_condition
        precomputed_latent = latent_type
        self.audio_emb_folder = (
            audio_emb_folder if audio_emb_folder is not None else audio_folder
        )
        self.skip_frames = skip_frames
        self.get_separate_id = get_separate_id
        self.fps = fps
        self.virtual_increase = virtual_increase
        self.select_randomly = select_randomly
        self.use_emotions = use_emotions
        self.emotions_folder = emotions_folder
        self.get_original_frames = get_original_frames
        self.add_extra_audio_emb = add_extra_audio_emb
        self.expand_box = expand_box
        self.nose_index = nose_index
        self.landmarks_folder = landmarks_folder
        self.what_mask = what_mask
        self.get_masks = get_masks

        assert not (exists(data_mean) ^ exists(data_std)), (
            "Both data_mean and data_std should be provided"
        )

        if data_mean is not None:
            data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
            data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
        self.data_mean = data_mean
        self.data_std = data_std

        self.motion_id = motion_id
        self.latent_folder = (
            latent_folder if latent_folder is not None else video_folder
        )
        self.audio_in_video = audio_in_video

        self.filelist = []
        self.audio_filelist = []
        self.landmark_filelist = [] if get_masks else None
        with open(filelist, "r") as files:
            for f in files.readlines():
                f = f.rstrip()
                audio_path = f.replace(video_folder, audio_folder).replace(
                    video_extension, audio_extension
                )

                self.filelist += [f]
                self.audio_filelist += [audio_path]

                if self.get_masks:
                    landmark_path = f.replace(video_folder, landmarks_folder).replace(
                        video_extension, ".npy"
                    )
                    self.landmark_filelist += [landmark_path]

        self.resize_size = resize_size
        if use_latent and not precomputed_latent:
            self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
        self.scale_audio = scale_audio
        self.step = step
        self.use_latent = use_latent
        self.precomputed_latent = precomputed_latent
        self.latent_type = latent_type
        self.latent_scale = latent_scale
        self.video_ext = video_extension
        self.video_folder = video_folder

        self.augment = augment
        self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
        self.maybe_augment_audio = (
            Compose(
                [
                    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
                    # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
                    PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
                    # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
                ]
            )
            if augment_audio
            else lambda x, sample_rate: x
        )
        self.maybe_augment_audio = partial(
            self.maybe_augment_audio, sample_rate=audio_rate
        )

        self.mode = mode
        if mode == "interpolation":
            need_cond = False  # Interpolation needs no condition frame: the first and last frames become the condition

        self.need_cond = need_cond  # If need_cond, one extra frame is extracted beyond num_frames
        if get_separate_id:
            self.need_cond = True
        # Used by the conditional model when the condition is not on the temporal dimension
        num_frames = num_frames if not self.need_cond else num_frames + 1

        vr = decord.VideoReader(self.filelist[0])
        self.video_rate = math.ceil(vr.get_avg_fps())
        print(f"Video rate: {self.video_rate}")
        self.audio_rate = audio_rate
        a2v_ratio = fps / float(self.audio_rate)
        self.samples_per_frame = math.ceil(1 / a2v_ratio)

        if get_separate_id:
            assert mode == "prediction", (
                "Separate identity frame is only supported for prediction mode"
            )
            # No need for an extra frame if we are getting a separate identity frame
            self.need_cond = True
            num_frames -= 1
        self.num_frames = num_frames

        self.load_all_possible_indexes = load_all_possible_indexes
        if load_all_possible_indexes:
            self._indexes = self._get_indexes(
                self.filelist, self.audio_filelist, self.landmark_filelist
            )
        else:
            if filter_by_length:
                self._indexes = self.filter_by_length(
                    self.filelist, self.audio_filelist, self.landmark_filelist
                )
            else:
                if self.get_masks:
                    self._indexes = list(
                        zip(self.filelist, self.audio_filelist, self.landmark_filelist)
                    )
                else:
                    self._indexes = list(
                        zip(
                            self.filelist,
                            self.audio_filelist,
                            [None] * len(self.filelist),
                        )
                    )

        self.balance_datasets = balance_datasets
        if self.balance_datasets:
            self.weights = self._calculate_weights()
            self.sampler = WeightedRandomSampler(
                self.weights, num_samples=len(self._indexes), replacement=True
            )

    def __len__(self):
        return len(self._indexes) * self.virtual_increase

    def _load_landmarks(self, filename, original_size, target_size, indexes):
        landmarks = np.load(filename, allow_pickle=True)[indexes, :]
        if self.what_mask == "full":
            mask = create_masks_from_landmarks_full_size(
                landmarks,
                original_size[0],
                original_size[1],
                offset=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "box":
            mask = create_masks_from_landmarks_box(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "heart":
            mask = face_mask_cheeks_batch(
                original_size, landmarks, box_expand=0.0, show_nose=True
            )
        elif self.what_mask == "mouth":
            mask = create_masks_from_landmarks_mouth(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=0.01,
                nose_index=self.nose_index,
            )
        else:
            mask = create_face_mask_from_landmarks(
                landmarks, original_size[0], original_size[1], mask_expand=0.05
            )
        # Interpolate the mask to the target size
        mask = F.interpolate(
            mask.unsqueeze(1).float(), size=target_size, mode="nearest"
        )
        return mask, landmarks

    def get_emotions(self, video_file, video_indexes):
        emotions_path = video_file.replace(
            self.video_folder, self.emotions_folder
        ).replace(self.video_ext, ".pt")
        emotions = torch.load(emotions_path)
        return (
            emotions["valence"][video_indexes],
            emotions["arousal"][video_indexes],
            emotions["labels"][video_indexes],
        )

    def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
        if select_randomly:
            # Randomly select self.num_frames indices from the available range
            available_indices = list(range(start_idx, total_video_frames))
            if len(available_indices) < self.num_frames:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )
            indexes = random.sample(available_indices, self.num_frames)
            return sorted(indexes)  # Sort to maintain temporal order
        else:
            # Calculate the maximum possible start index
            max_start_idx = total_video_frames - (
                (self.num_frames - 1) * (self.skip_frames + 1) + 1
            )

            # Generate a random start index
            if max_start_idx > 0:
                start_idx = np.random.randint(start_idx, max_start_idx)
            else:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )

            # Generate the indices
            indexes = [
                start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
            ]
            return indexes
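
    # Example (sketch, with hypothetical numbers): num_frames=5 and skip_frames=1 need a
    # window of (5 - 1) * 2 + 1 = 9 frames, so a sampled start_idx of 3 yields
    # indexes [3, 5, 7, 9, 11]; select_randomly instead draws 5 sorted indices
    # anywhere in [start_idx, total_video_frames).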

    def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
        audio, sr = sf.read(
            filename,
            start=math.ceil(start * self.audio_rate),
            frames=math.ceil(self.audio_rate * max_len_sec),
            always_2d=True,
        )  # e.g. (16000, 1)
        audio = audio.T  # (1, 16000)
        assert sr == self.audio_rate, (
            f"Audio rate is {sr} but should be {self.audio_rate}"
        )
        audio = audio.mean(0, keepdims=True)
        audio = self.maybe_augment_audio(audio)
        audio = torch.from_numpy(audio).float()
        # audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
        audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
        return audio[0]

    def ensure_shape(self, tensors):
        target_length = self.samples_per_frame
        processed_tensors = []
        for tensor in tensors:
            current_length = tensor.shape[1]
            diff = current_length - target_length
            assert abs(diff) <= 5, (
                f"Expected shape {target_length}, but got {current_length}"
            )
            if diff < 0:
                # Calculate how much padding is needed
                padding_needed = target_length - current_length
                # Pad the tensor
                padded_tensor = F.pad(tensor, (0, padding_needed))
                processed_tensors.append(padded_tensor)
            elif diff > 0:
                # Trim the tensor
                trimmed_tensor = tensor[:, :target_length]
                processed_tensors.append(trimmed_tensor)
            else:
                # Already the correct size
                processed_tensors.append(tensor)
        return torch.cat(processed_tensors)
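
    # Note (sketch): with the defaults fps=25 and audio_rate=16000, samples_per_frame is
    # ceil(16000 / 25) = 640, so ensure_shape pads/trims each per-frame audio chunk
    # returned by decord.AVReader to exactly 640 samples before concatenating them
    # (e.g. 5 frames -> a (5, 640) tensor, flattened to 3200 samples downstream).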

    def normalize_latents(self, latents):
        if self.data_mean is not None:
            # Normalize latents to zero mean and 0.5 std
            latents = ((latents - self.data_mean) / self.data_std) * 0.5
        return latents

    def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
        ratio = fps_to / fps_from
        indexes_60fps = [int(index * ratio) for index in indexes_25fps]
        return indexes_60fps
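
    # Example (sketch): convert_indexes([0, 1, 2]) with fps_from=25 and fps_to=60
    # returns [0, 2, 4] (int(i * 60 / 25)); this is how clip indices sampled at 25 fps
    # are remapped onto the 60 fps "AA_processed" / "1000actors_nsv" videos in
    # _get_frames_and_audio.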

    def _get_frames_and_audio(self, idx):
        if self.load_all_possible_indexes:
            indexes, video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)
        else:
            video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)

            indexes = self.get_frame_indices(
                len_video,
                select_randomly=self.select_randomly,
                start_idx=120 if "1000actors_nsv" in video_file else 0,
            )

        if self.get_separate_id:
            id_idx = np.random.randint(0, len_video)
            indexes.insert(0, id_idx)

        if "AA_processed" in video_file or "1000actors_nsv" in video_file:
            video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
            audio_file = audio_file.replace("_output_output", "")
            if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
                audio_path_extra = ".safetensors"
            else:
                audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"

            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = (
                ".pt" if "AA_processed" in video_file else "_beats_emb.pt"
            )
        else:
            video_indexes = indexes
            audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = "_beats_emb.pt"

        emotions = None
        if self.use_emotions:
            emotions = self.get_emotions(video_file, video_indexes)
            if self.get_separate_id:
                emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])

        raw_audio = None
        if self.audio_in_video:
            raw_audio, frames_video = vr.get_batch(video_indexes)
            raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")

        if self.use_latent and self.precomputed_latent:
            latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
                self.video_folder, self.latent_folder
            )
            frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
            if frames.shape[-1] != 64:
                print(f"Frames shape: {frames.shape}, video file: {video_file}")
            frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
            frames = self.normalize_latents(frames)
        else:
            if self.audio_in_video:
                frames = frames_video.permute(3, 0, 1, 2).float()
            else:
                frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()

        if raw_audio is None:
            # Audio is not embedded in the video file; load it from the separate track
            raw_audio = self._load_audio(
                audio_file,
                max_len_sec=frames.shape[1] / self.fps,
                start=indexes[0] / self.fps,
                # indexes=indexes,
            )

        if not self.from_audio_embedding:
            audio = raw_audio
            audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
        else:
            audio = load_file(
                audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[0]
                + audio_path_extra
            )["audio"]
            audio_frames = audio[indexes, :]
            if self.add_extra_audio_emb:
                audio_extra = torch.load(
                    audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[0]
                    + audio_path_extra_extra
                )
                audio_extra = audio_extra[indexes, :]
                audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)

        audio_frames = (
            audio_frames[1:] if self.need_cond else audio_frames
        )  # Remove audio of the first (conditioning) frame

        if self.get_original_frames:
            original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
            original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
            original_frames = (
                original_frames[:, 1:] if self.need_cond else original_frames
            )
        else:
            original_frames = None

        if not self.use_latent or (self.use_latent and not self.precomputed_latent):
            frames = self.scale_and_crop((frames / 255.0) * 2 - 1)

        target = frames[:, 1:] if self.need_cond else frames
        if self.mode == "prediction":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = (
                        frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                else:
                    clean_cond = (
                        vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(0)
                if self.latent_condition:
                    noisy_cond = frames[:, 0]
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = frames[:, 0]
                noisy_cond = clean_cond
        elif self.mode == "interpolation":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
                else:
                    clean_cond = (
                        vr.get_batch([video_indexes[0], video_indexes[-1]])
                        .permute(3, 0, 1, 2)
                        .float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
                if self.latent_condition:
                    noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                noisy_cond = clean_cond
        # Add noise to the conditioning frame
        if self.cond_noise and isinstance(self.cond_noise, (list, ListConfig)):
            # cond_noise holds the (mean, std) of a log-normal noise level
            cond_noise = (
                self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
            ).exp()
            noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
        else:
            noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
            cond_noise = self.cond_noise
        if self.get_masks:
            target_size = (
                (self.resize_size, self.resize_size)
                if not self.use_latent
                else (self.resize_size // 8, self.resize_size // 8)
            )
            masks, landmarks = self._load_landmarks(
                land_file, original_size, target_size, video_indexes
            )
            landmarks = None
            masks = (
                masks.permute(1, 0, 2, 3)[:, 1:]
                if self.need_cond
                else masks.permute(1, 0, 2, 3)
            )
        else:
            masks = None
            landmarks = None

        return (
            original_frames,
            clean_cond,
            noisy_cond,
            target,
            audio_frames,
            raw_audio,
            cond_noise,
            emotions,
            masks,
            landmarks,
        )

    def filter_by_length(self, video_filelist, audio_filelist, landmark_filelist=None):
        def with_opencv(filename):
            video = cv2.VideoCapture(filename)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            video.release()
            return int(frame_count)

        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)

        filtered_video = []
        filtered_audio = []
        filtered_land = []
        min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
        for vid_file, audio_file, land_file in tqdm(
            zip(video_filelist, audio_filelist, landmark_filelist),
            total=len(video_filelist),
            desc="Filtering",
        ):
            # vr = decord.VideoReader(vid_file)
            len_video = with_opencv(vid_file)
            # Skip videos that are too short for the requested clip length
            if len_video < min_length:
                continue
            filtered_video.append(vid_file)
            filtered_audio.append(audio_file)
            filtered_land.append(land_file)

        print(f"New number of files: {len(filtered_video)}")
        # Return (video, audio, landmark) tuples so the output matches self._indexes
        return list(zip(filtered_video, filtered_audio, filtered_land))

    def _get_indexes(self, video_filelist, audio_filelist, landmark_filelist=None):
        indexes = []
        self.og_shape = None
        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)
        for vid_file, audio_file, land_file in zip(
            video_filelist, audio_filelist, landmark_filelist
        ):
            vr = decord.VideoReader(vid_file)
            if self.og_shape is None:
                self.og_shape = vr[0].shape[-2]
            len_video = len(vr)
            # Skip short videos
            if len_video < self.num_frames:
                continue
            else:
                possible_indexes = list(
                    sliding_window(range(len_video), self.num_frames)
                )[:: self.step]
                possible_indexes = list(
                    map(lambda x: (x, vid_file, audio_file, land_file), possible_indexes)
                )
                indexes.extend(possible_indexes)
        print("Indexes", len(indexes), "\n")
        return indexes

    def scale_and_crop(self, video):
        h, w = video.shape[-2], video.shape[-1]
        # Scale the shorter side to the target resolution
        if self.resize_size is not None:
            scale = self.resize_size / min(h, w)
            if h < w:
                target_size = (self.resize_size, math.ceil(w * scale))
            else:
                target_size = (math.ceil(h * scale), self.resize_size)
            video = F.interpolate(
                video,
                size=target_size,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # Center crop
            h, w = video.shape[-2], video.shape[-1]
            w_start = (w - self.resize_size) // 2
            h_start = (h - self.resize_size) // 2
            video = video[
                :,
                :,
                h_start : h_start + self.resize_size,
                w_start : w_start + self.resize_size,
            ]
        return self.maybe_augment(video)
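
    # Example (sketch, with hypothetical sizes): for a 270x480 frame and resize_size=256,
    # the shorter side is scaled by 256 / 270, giving a 256x456 clip, which is then
    # center-cropped to 256x256 before the optional horizontal-flip augmentation.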

    def _calculate_weights(self):
        # self._indexes holds either (video, audio, landmark) tuples or, when
        # load_all_possible_indexes is set, (frame_indexes, video, audio, landmark) tuples
        def video_path(item):
            return item[1] if len(item) == 4 else item[0]

        aa_processed_count = sum(
            1 for item in self._indexes if "AA_processed" in video_path(item)
        )
        nsv_processed_count = sum(
            1 for item in self._indexes if "1000actors_nsv" in video_path(item)
        )
        other_count = len(self._indexes) - aa_processed_count - nsv_processed_count

        aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
        nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
        other_weight = 1 / other_count if other_count > 0 else 0

        print(
            f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
        )
        print(f"AA processed weight: {aa_processed_weight}")
        print(f"NSV processed weight: {nsv_processed_weight}")
        print(f"Other weight: {other_weight}")

        weights = [
            aa_processed_weight
            if "AA_processed" in video_path(item)
            else nsv_processed_weight
            if "1000actors_nsv" in video_path(item)
            else other_weight
            for item in self._indexes
        ]
        return weights

    def __getitem__(self, idx):
        if self.balance_datasets:
            idx = next(iter(self.sampler))

        try:
            (
                original_frames,
                clean_cond,
                noisy_cond,
                target,
                audio,
                raw_audio,
                cond_noise,
                emotions,
                masks,
                landmarks,
            ) = self._get_frames_and_audio(idx % len(self._indexes))
        except Exception as e:
            print(f"Error with index {idx}: {e}")
            return self.__getitem__(np.random.randint(0, len(self)))

        out_data = {}

        if original_frames is not None:
            out_data["original_frames"] = original_frames

        if audio is not None:
            out_data["audio_emb"] = audio
            out_data["raw_audio"] = raw_audio

        if self.use_emotions:
            out_data["valence"] = emotions[0]
            out_data["arousal"] = emotions[1]
            out_data["emo_labels"] = emotions[2]

        if self.use_latent:
            input_key = "latents"
        else:
            input_key = "frames"
        out_data[input_key] = target
        if noisy_cond is not None:
            out_data["cond_frames"] = noisy_cond
        out_data["cond_frames_without_noise"] = clean_cond
        if cond_noise is not None:
            out_data["cond_aug"] = cond_noise
        if masks is not None:
            out_data["masks"] = masks
            out_data["gt"] = target
        if landmarks is not None:
            out_data["landmarks"] = landmarks
        out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
        out_data["fps_id"] = torch.tensor([self.fps - 1])
        out_data["num_video_frames"] = self.num_frames
        out_data["image_only_indicator"] = torch.zeros(self.num_frames)
        return out_data


if __name__ == "__main__":
    import cv2

    # Minimal smoke test. The previous version passed a torchvision Resize transform,
    # but the constructor has no `transform` argument, so the resize is expressed via
    # resize_size instead, and samples are read from the returned dict.
    dataset = VideoDataset(
        "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
        resize_size=256,
        num_frames=25,
    )
    print(len(dataset))

    idx = np.random.randint(0, len(dataset))

    for i in range(10):
        sample = dataset[i]
        print(sample["frames"].shape, sample["cond_frames_without_noise"].shape)

    sample = dataset[idx]
    image_identity = (
        sample["cond_frames_without_noise"].permute(1, 2, 0).numpy() + 1
    ) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])

    for i in range(25):
        image = (sample["frames"][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
        cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])