|
import os |
|
import pdb |
|
|
|
import torch |
|
import numpy as np |
|
import pickle |
|
from tqdm import tqdm |
|
from transformers import Wav2Vec2Processor |
|
import librosa |
|
from collections import defaultdict |
|
from torch.utils import data |
|
|
|
|
|
class Dataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader.

    Each element of `data` is a dict with keys "name", "audio",
    "vertice" and "template"; __getitem__ converts them to FloatTensors
    and attaches a one-hot subject-identity vector derived from
    subjects_dict["train"].
    """

    def __init__(self, data, subjects_dict, data_type="train", read_audio=False):
        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        # One row per training subject; row i is the identity vector of
        # subjects_dict["train"][i].
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))
        # When False, the "audio" entry is skipped in __getitem__.
        self.read_audio = read_audio

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        file_name = self.data[index]["name"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]

        # FIX: the original branched on self.data_type == "train", but both
        # branches contained byte-identical logic, so the split is removed.
        # NOTE(review): even for val/test items the subject is looked up in
        # subjects_dict["train"] -- presumably subject identities are shared
        # across splits; confirm against the caller.
        if len(self.one_hot_labels) == 1:
            one_hot = self.one_hot_labels[0]
        else:
            # File names are "<subject>_<index>.npy"; the subject token is
            # capitalized to match the entries of subjects_dict["train"].
            subject = file_name.split("_")[0]
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject.capitalize())]

        if self.read_audio:
            return torch.FloatTensor(audio), torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name
        else:
            return torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name

    def __len__(self):
        return self.len
|
|
|
def read_data(args, test_config=False):
    """Scan the dataset directories and split samples into train/val/test.

    Walks args.data_root/args.wav_path for .wav files, pairs each with a
    vertex file of the same stem under args.data_root/args.vertices_path,
    and buckets samples according to train.txt / test.txt (the first 90%
    of train-listed files become train, the remainder validation).

    Parameters:
        args: namespace providing data_root, wav_path, vertices_path,
            read_audio, wav2vec2model_path, template_file, dataset,
            train_subjects, val_subjects, test_subjects.
        test_config: when True, only files listed in test.txt are loaded.

    Returns:
        (train_data, valid_data, test_data, subjects_dict) where each
        *_data is a list of sample dicts and subjects_dict maps
        "train"/"val"/"test" to subject-name lists.
    """
    print("Loading data...")
    data = defaultdict(dict)
    train_data = []
    valid_data = []
    test_data = []

    audio_path = os.path.join(args.data_root, args.wav_path)
    vertices_path = os.path.join(args.data_root, args.vertices_path)
    if args.read_audio:
        # Only needed when raw waveforms are loaded and tokenized.
        processor = Wav2Vec2Processor.from_pretrained(args.wav2vec2model_path)

    template_file = os.path.join(args.data_root, args.template_file)
    with open(template_file, 'rb') as fin:
        templates = pickle.load(fin, encoding='latin1')

    # FIX: the split files were previously opened without ever being
    # closed; use context managers so the handles are released.
    with open(os.path.join(args.data_root, "train.txt"), "r") as train_txt:
        train_list = [line.split("\n")[0] for line in train_txt.readlines()]
    with open(os.path.join(args.data_root, "test.txt"), "r") as test_txt:
        test_list = [line.split("\n")[0] for line in test_txt.readlines()]

    for r, ds, fs in os.walk(audio_path):
        for f in tqdm(fs):
            # In test-only configuration, skip anything not in test.txt.
            if test_config and f not in test_list:
                continue
            if not f.endswith("wav"):
                continue

            if args.read_audio:
                wav_path = os.path.join(r, f)
                speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
                input_values = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
            key = f.replace("wav", "npy")
            data[key]["audio"] = input_values if args.read_audio else None
            subject_id = "_".join(key.split("_")[:-1])

            # NOTE(review): the template is looked up under the literal key
            # "id", not subject_id (which is computed but never used) --
            # presumably all samples share a single template; confirm
            # against how the template pickle is produced.
            temp = templates["id"]
            data[key]["name"] = f
            data[key]["template"] = temp.reshape((-1))

            vertice_path = os.path.join(vertices_path, f.replace("wav", "npz"))
            if not os.path.exists(vertice_path):
                # No ground-truth vertices for this clip: drop the sample.
                del data[key]
            else:
                if args.dataset == "vocaset":
                    # Keep every other frame for vocaset.
                    data[key]["vertice"] = np.load(vertice_path, allow_pickle=True)[::2, :]
                elif args.dataset == "BIWI":
                    data[key]["vertice"] = np.load(vertice_path, allow_pickle=True)
                elif args.dataset == "multi":
                    # FLAME parameter archive: flatten per-frame vertices.
                    flame_param = np.load(vertice_path, allow_pickle=True)
                    data[key]["vertice"] = flame_param["verts"].reshape((flame_param["verts"].shape[0], -1))

    subjects_dict = {}
    subjects_dict["train"] = args.train_subjects.split(" ")
    subjects_dict["val"] = args.val_subjects.split(" ")
    subjects_dict["test"] = args.test_subjects.split(" ")

    # First 90% of the files found from train.txt go to train, the rest
    # to validation; anything from test.txt goes to test.
    train_cnt = 0
    for k, v in data.items():
        k_wav = k.replace("npy", "wav")
        if k_wav in train_list:
            if train_cnt < int(len(train_list) * 0.9):
                train_data.append(v)
            else:
                valid_data.append(v)
            train_cnt += 1
        elif k_wav in test_list:
            test_data.append(v)

    print('Loaded data: Train-{}, Val-{}, Test-{}'.format(len(train_data), len(valid_data), len(test_data)))
    return train_data, valid_data, test_data, subjects_dict
|
|
|
|
|
def get_dataloaders(args, test_config=False):
    """Build DataLoaders for the train/valid/test splits.

    Returns a dict with keys "train", "valid" and "test"; when
    test_config is True only the "test" loader is built.
    """
    train_data, valid_data, test_data, subjects_dict = read_data(args, test_config)

    dataset = {}
    if not test_config:
        dataset["train"] = data.DataLoader(
            dataset=Dataset(train_data, subjects_dict, "train", args.read_audio),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
        )
        dataset["valid"] = data.DataLoader(
            dataset=Dataset(valid_data, subjects_dict, "val", args.read_audio),
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
        )
    dataset["test"] = data.DataLoader(
        dataset=Dataset(test_data, subjects_dict, "test", args.read_audio),
        batch_size=1,
        shuffle=True,
        num_workers=args.workers,
    )
    return dataset
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): get_dataloaders requires an `args` namespace (see its
    # signature), so this call with no arguments raises a TypeError at
    # runtime -- the entry point is broken as written; presumably argument
    # parsing was intended here. TODO confirm and fix.
    get_dataloaders()