import os import pdb import torch import numpy as np import pickle from tqdm import tqdm from transformers import Wav2Vec2Processor import librosa from collections import defaultdict from torch.utils import data class Dataset(data.Dataset): """Custom data.Dataset compatible with data.DataLoader.""" def __init__(self, data, subjects_dict, data_type="train", read_audio=False): self.data = data self.len = len(self.data) self.subjects_dict = subjects_dict self.data_type = data_type #self.one_hot_labels = np.eye(1) self.one_hot_labels = np.eye(len(subjects_dict["train"])) self.read_audio = read_audio def __getitem__(self, index): """Returns one data pair (source and target).""" # seq_len, fea_dim file_name = self.data[index]["name"] audio = self.data[index]["audio"] vertice = self.data[index]["vertice"] template = self.data[index]["template"] if self.data_type == "train": if len(self.one_hot_labels)==1: one_hot = self.one_hot_labels[0] else: #subject = "_".join(file_name.split("_")[:-1]) subject = file_name.split("_")[0] one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject.capitalize())] else: #one_hot = self.one_hot_labels if len(self.one_hot_labels)==1: one_hot = self.one_hot_labels[0] else: subject = file_name.split("_")[0] one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject.capitalize())] if self.read_audio: return torch.FloatTensor(audio), torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name else: return torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name def __len__(self): return self.len def read_data(args, test_config=False): print("Loading data...") data = defaultdict(dict) train_data = [] valid_data = [] test_data = [] audio_path = os.path.join(args.data_root, args.wav_path) vertices_path = os.path.join(args.data_root, args.vertices_path) if args.read_audio: # read_audio==False when training vq to save time processor = Wav2Vec2Processor.from_pretrained(args.wav2vec2model_path) template_file = os.path.join(args.data_root, args.template_file) with open(template_file, 'rb') as fin: templates = pickle.load(fin, encoding='latin1') cnt=0 ####spliting train, val, test train_txt = open(os.path.join(args.data_root,"train.txt"), "r") test_txt = open(os.path.join(args.data_root,"test.txt"), "r") train_lines, test_lines, train_list, test_list = train_txt.readlines(), test_txt.readlines(), [], [] for tt in train_lines: train_list.append(tt.split("\n")[0]) for tt in test_lines: test_list.append(tt.split("\n")[0]) for r, ds, fs in os.walk(audio_path): for f in tqdm(fs): ###Activate when testing the model if test_config and f not in test_list: continue if f.endswith("wav"): if args.read_audio: wav_path = os.path.join(r, f) speech_array, sampling_rate = librosa.load(wav_path, sr=16000) input_values = np.squeeze(processor(speech_array, sampling_rate=16000).input_values) key = f.replace("wav", "npy") data[key]["audio"] = input_values if args.read_audio else None subject_id = "_".join(key.split("_")[:-1]) #temp = templates[subject_id] temp = templates["id"] data[key]["name"] = f data[key]["template"] = temp.reshape((-1)) vertice_path = os.path.join(vertices_path, f.replace("wav", "npz")) if not os.path.exists(vertice_path): del data[key] else: if args.dataset == "vocaset": data[key]["vertice"] = np.load(vertice_path, allow_pickle=True)[::2, :] # due to the memory limit elif args.dataset == "BIWI": data[key]["vertice"] = np.load(vertice_path, allow_pickle=True) elif args.dataset=="multi": flame_param = np.load(vertice_path, allow_pickle=True) data[key]["vertice"] = flame_param["verts"].reshape((flame_param["verts"].shape[0], -1)) subjects_dict = {} subjects_dict["train"] = [i for i in args.train_subjects.split(" ")] subjects_dict["val"] = [i for i in args.val_subjects.split(" ")] subjects_dict["test"] = [i for i in args.test_subjects.split(" ")] # train vq and pred train_cnt = 0 for k, v in data.items(): k_wav = k.replace("npy", "wav") if k_wav in train_list: if train_cnt