# Source: Flux9665, commit 70399da ("update to the current version").
import os
import numpy as np
import torch
class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
    """Dataset of pre-computed speaker embeddings, standardized on access.

    Embeddings are loaded with ``torch.load`` from ``feature_path``. Each item
    is an ``(embedding, torch.zeros([0]))`` pair: the second element is an
    empty placeholder, not a label. Standardization uses one scalar mean and
    one scalar standard deviation computed over the whole loaded tensor.

    Args:
        feature_path: Path (str or path-like) to a file saved with
            ``torch.save`` holding a tensor or a list of equally-shaped
            tensors (stacked along a new first dimension). If it is not a
            regular file, the index-file variant in ``_load_features`` is used.
        device: Passed to ``torch.load(map_location=...)`` and used as the
            device of labels created by ``set_labels``.
        mode: ``'utterance'`` or ``'speaker'``; selects the ``utt2idx`` or
            ``spk2idx`` index file in the index-file variant.
    """

    # Maps the public ``mode`` argument to the index-file prefix.
    _MODE_PREFIXES = {'utterance': 'utt', 'speaker': 'spk'}

    def __init__(self, feature_path, device, mode='utterance'):
        super().__init__()
        assert mode in self._MODE_PREFIXES, f'mode: {mode} is not supported'
        self.mode = self._MODE_PREFIXES[mode]
        self.device = device
        # Labels only exist after set_labels(). Initialise them so the first
        # set_labels() call can record the previous labels without failing.
        self.y = None
        self.y_old = None
        self.x, self.speakers = self._load_features(feature_path)

    def __len__(self):
        """Return the number of items (one per entry in ``self.speakers``)."""
        return len(self.speakers)

    def __getitem__(self, index):
        """Return ``(standardized embedding, empty placeholder tensor)``."""
        return self.normalize_embedding(self.x[index]), torch.zeros([0])

    def normalize_embedding(self, vector):
        """Standardize ``vector`` with the dataset-wide scalar mean and std."""
        return (vector - self.mean) / self.std

    def get_speaker(self, label):
        """Map a class label back to its speaker id.

        NOTE(review): ``self.class2spk`` is never assigned in this class, so
        this raises AttributeError unless a caller sets it externally. Confirm
        whether this method is still used.
        """
        return self.class2spk[label]

    def get_embedding_dim(self):
        """Return the size of the last dimension of the embedding tensor."""
        return self.x.shape[-1]

    def get_num_speaker(self):
        """Return the number of distinct values in the current labels.

        Raises:
            RuntimeError: if no labels have been assigned via ``set_labels``.
        """
        if self.y is None:
            raise RuntimeError('no labels assigned; call set_labels() first')
        return len(torch.unique(self.y))

    def set_labels(self, labels):
        """Assign the same label value to every item.

        The previous labels (``None`` if none were set) are kept in
        ``self.y_old``.

        Args:
            labels: Scalar fill value for a new label tensor of length
                ``len(self)``, created on ``self.device``.
        """
        self.y_old = self.y
        self.y = torch.full(size=(len(self),), fill_value=labels, device=self.device)

    @staticmethod
    def _read_two_column_file(path):
        """Parse a whitespace-separated ``key value`` text file into a dict.

        Lines that do not have exactly two fields (e.g. blank lines) are
        skipped. A later duplicate key overwrites the earlier value but keeps
        its original insertion position.
        """
        mapping = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if len(fields) == 2:
                    mapping[fields[0]] = fields[1]
        return mapping

    def _load_features(self, feature_path):
        """Load embeddings and per-item speaker ids from ``feature_path``.

        Also sets ``self.mean`` and ``self.std`` (scalar statistics over the
        whole embedding tensor), which ``normalize_embedding`` uses.

        Returns:
            ``(vectors, speakers)``. For a single file, ``speakers`` is a zero
            tensor with one entry per row of ``vectors``. Otherwise ``vectors``
            is re-ordered by the indices in the ``<mode>2idx`` file, and
            ``speakers`` holds that file's ids, mapped through ``utt2spk``
            when that file exists.

        Raises:
            ValueError: if the ``<mode>2idx`` file contains no usable entries.
        """
        if os.path.isfile(feature_path):
            vectors = torch.load(feature_path, map_location=self.device)
            if isinstance(vectors, list):
                vectors = torch.stack(vectors)
            self.mean = torch.mean(vectors)
            self.std = torch.std(vectors)
            return vectors, torch.zeros(vectors.size(0))

        # NOTE(review): feature_path is not a regular file here (presumably a
        # directory that holds the index files read below), yet it is passed
        # to torch.load unchanged, and torch.load cannot load a directory.
        # Confirm which file inside it should be loaded.
        vectors = torch.load(feature_path, map_location=self.device)
        self.mean = torch.mean(vectors)
        self.std = torch.std(vectors)

        idx_path = os.path.join(feature_path, f'{self.mode}2idx')
        spk2idx = {key: int(value) for key, value in self._read_two_column_file(idx_path).items()}
        if not spk2idx:
            raise ValueError(f'no "<id> <index>" entries found in {idx_path}')
        speakers, indices = zip(*spk2idx.items())

        utt2spk_path = os.path.join(feature_path, 'utt2spk')
        if os.path.exists(utt2spk_path):
            # In this case the index file holds utterance ids, not speaker ids.
            utt2spk = self._read_two_column_file(utt2spk_path)
            speakers = [utt2spk[utt] for utt in speakers]
        return vectors[np.array(indices)], speakers

    def _reformat_features(self, features):
        """Reshape 2-D ``(N, D)`` features to ``(N, 1, 1, D)``.

        Features of any other rank are returned unchanged.
        """
        if len(features.shape) == 2:
            return features.reshape(features.shape[0], 1, 1, features.shape[1])
        return features