Spaces:
Running
on
T4
Running
on
T4
import os | |
import numpy as np | |
import torch | |
class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
    """Dataset of speaker embedding vectors loaded with ``torch.load``.

    Items are ``(normalized_embedding, placeholder)`` pairs, where the
    placeholder is an empty tensor (see ``__getitem__``). Normalization uses
    a single scalar mean and standard deviation computed over all loaded
    vectors.

    Attributes set by ``__init__``:
        mode: ``'utt'`` or ``'spk'``; prefix of the ``<mode>2idx`` index file
            read when ``feature_path`` is not a regular file.
        device: passed to ``torch.load`` as ``map_location`` and used for the
            label tensor created by ``set_labels``.
        x: the embedding vectors.
        speakers: a zero tensor with one entry per vector (single-file case)
            or a sequence of speaker/utterance ids (directory case).
        mean, std: scalar statistics used by ``normalize_embedding``.

    NOTE: per-speaker class labels are not derived at construction time, so
    ``y`` only exists after ``set_labels`` has been called and ``class2spk``
    is never populated by this class.
    """

    # Public mode name -> prefix of the '<prefix>2idx' index file.
    _MODE_PREFIXES = {'utterance': 'utt', 'speaker': 'spk'}

    def __init__(self, feature_path, device, mode='utterance'):
        """Load the vectors found at ``feature_path``.

        Args:
            feature_path: a file readable by ``torch.load``, or a directory
                (see ``_load_features``).
            device: ``map_location`` for ``torch.load``.
            mode: ``'utterance'`` or ``'speaker'``.

        Raises:
            AssertionError: if ``mode`` is not supported.
        """
        super().__init__()
        assert mode in self._MODE_PREFIXES, f'mode: {mode} is not supported'
        # Mapping lookup (instead of if/elif) so that an invalid mode still
        # fails loudly when asserts are stripped with ``python -O``.
        self.mode = self._MODE_PREFIXES[mode]
        self.device = device
        self.x, self.speakers = self._load_features(feature_path)

    def __len__(self):
        """Return the number of entries in ``self.speakers``."""
        return len(self.speakers)

    def __getitem__(self, index):
        """Return ``(normalized embedding, empty tensor)`` for ``index``.

        The second element is a zero-length tensor standing in for a label;
        no speaker label is returned.
        """
        embedding = self.normalize_embedding(self.x[index])
        return embedding, torch.zeros([0])

    def normalize_embedding(self, vector):
        """Standardize ``vector`` with the dataset-wide scalar mean and std."""
        return (vector - self.mean) / self.std

    def get_speaker(self, label):
        """Return the speaker mapped to class ``label`` in ``self.class2spk``.

        NOTE(review): ``class2spk`` is not assigned anywhere in this class, so
        this raises ``AttributeError`` unless a caller sets it first.
        """
        return self.class2spk[label]

    def get_embedding_dim(self):
        """Return the size of the last dimension of the stored vectors."""
        return self.x.shape[-1]

    def get_num_speaker(self):
        """Return the number of distinct labels.

        Uses ``self.y`` when labels have been assigned via ``set_labels``;
        otherwise counts the distinct entries of ``self.speakers`` instead of
        failing on the missing attribute.
        """
        labels = getattr(self, 'y', None)
        if labels is not None:
            return len(torch.unique(labels))
        if torch.is_tensor(self.speakers):
            return len(torch.unique(self.speakers))
        return len(set(self.speakers))

    def set_labels(self, labels):
        """Assign the scalar ``labels`` to every item.

        The previous label tensor is kept in ``self.y_old``; it is ``None``
        when no labels had been assigned before.
        """
        # ``y`` is not created by __init__, so it may legitimately be absent.
        self.y_old = getattr(self, 'y', None)
        self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device)

    def _fit_normalization(self, vectors):
        """Store the scalar mean and std of ``vectors`` on the instance."""
        self.mean = torch.mean(vectors)
        self.std = torch.std(vectors)

    @staticmethod
    def _read_two_column_file(path):
        """Parse a whitespace-separated ``key value`` file into a dict.

        Lines that do not consist of exactly two fields are ignored. Values
        are returned as strings.
        """
        mapping = {}
        with open(path, 'r') as handle:
            for line in handle:
                fields = line.split()
                if len(fields) == 2:
                    mapping[fields[0]] = fields[1]
        return mapping

    def _load_features(self, feature_path):
        """Load vectors and their speaker ids; also sets ``mean`` and ``std``.

        If ``feature_path`` is a regular file, it is loaded directly (a list
        of tensors is stacked) and the speakers are a zero tensor with one
        entry per vector.

        Otherwise ``feature_path`` is treated as a directory holding a
        ``<mode>2idx`` file that maps ids to row indices of the loaded
        vectors and, optionally, a ``utt2spk`` file that maps those ids to
        speaker ids.

        Returns:
            Tuple ``(vectors, speakers)``.
        """
        if os.path.isfile(feature_path):
            vectors = torch.load(feature_path, map_location=self.device)
            if isinstance(vectors, list):
                vectors = torch.stack(vectors)
            self._fit_normalization(vectors)
            return vectors, torch.zeros(vectors.size(0))

        # NOTE(review): this hands the directory itself to torch.load, which
        # presumably should be a vector file inside it. The intended file name
        # is not visible here, so the call is left unchanged -- confirm.
        vectors = torch.load(feature_path, map_location=self.device)
        self._fit_normalization(vectors)

        # os.path.join accepts both str and path-like objects, unlike the
        # ``/`` operator which only works on pathlib paths.
        idx_file = os.path.join(feature_path, f'{self.mode}2idx')
        id2idx = {
            key: int(value)
            for key, value in self._read_two_column_file(idx_file).items()
        }
        speakers, indices = zip(*id2idx.items())

        utt2spk_file = os.path.join(feature_path, 'utt2spk')
        if os.path.exists(utt2spk_file):
            # The index file holds utterance ids here, not speaker ids.
            utt2spk = self._read_two_column_file(utt2spk_file)
            speakers = [utt2spk[utt] for utt in speakers]

        return vectors[np.array(indices)], speakers

    def _reformat_features(self, features):
        """Reshape 2-D ``(n, d)`` features to ``(n, 1, 1, d)``.

        Returns ``None`` for input of any other rank.
        """
        if len(features.shape) == 2:
            return features.reshape(features.shape[0], 1, 1, features.shape[1])
        return None