File size: 3,348 Bytes
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os

import numpy as np
import torch


class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
    """Dataset serving speaker embeddings, standardized on access.

    The embeddings are loaded once from ``feature_path``. A global (scalar)
    mean and standard deviation over all loaded values are stored in
    ``self.mean`` / ``self.std`` and applied in ``__getitem__``.

    No class labels are derived at load time: ``self.y`` only exists after
    ``set_labels`` has been called, and ``self.class2spk`` is not populated
    anywhere in this class.
    """

    # Maps the public ``mode`` argument to the prefix used for the
    # ``<prefix>2idx`` index file read in the directory layout.
    _MODE_PREFIXES = {'utterance': 'utt', 'speaker': 'spk'}

    def __init__(self, feature_path, device, mode='utterance'):
        """Load the embeddings found at ``feature_path``.

        Args:
            feature_path: Either a file readable by ``torch.load`` or a
                directory (expected to support the ``/`` operator, e.g. a
                ``pathlib.Path``) containing a ``utt2idx`` / ``spk2idx`` index.
            device: Device the tensors are mapped to while loading and on
                which labels created by ``set_labels`` are placed.
            mode: ``'utterance'`` or ``'speaker'``.

        Raises:
            AssertionError: If ``mode`` is not one of the supported values.
        """
        super().__init__()

        assert mode in self._MODE_PREFIXES, f'mode: {mode} is not supported'
        self.mode = self._MODE_PREFIXES[mode]
        self.device = device

        # Must run after ``self.mode`` / ``self.device`` are set: the loader
        # reads both and also stores ``self.mean`` / ``self.std``.
        self.x, self.speakers = self._load_features(feature_path)

    def __len__(self):
        """Return the number of entries in ``self.speakers``."""
        return len(self.speakers)

    def __getitem__(self, index):
        """Return ``(normalized_embedding, placeholder)`` for ``index``.

        The second element is always an empty tensor (``torch.zeros([0])``);
        no speaker label is returned.
        """
        embedding = self.normalize_embedding(self.x[index])
        return embedding, torch.zeros([0])

    def normalize_embedding(self, vector):
        """Standardize ``vector`` with the stored global mean and std."""
        return torch.sub(vector, self.mean) / self.std

    def get_speaker(self, label):
        """Return the speaker id stored for class ``label``.

        NOTE(review): ``self.class2spk`` is never assigned in this class, so
        this raises ``AttributeError`` unless a caller sets it externally.
        """
        return self.class2spk[label]

    def get_embedding_dim(self):
        """Return the size of the last dimension of the embedding tensor."""
        return self.x.shape[-1]

    def get_num_speaker(self):
        """Return the number of distinct values in ``self.y``.

        Requires ``set_labels`` to have been called first, because ``self.y``
        is not created during construction.
        """
        return len(torch.unique(self.y))

    def set_labels(self, labels):
        """Assign the same label value to every item.

        The previous labels are kept in ``self.y_old``; it is ``None`` when no
        labels had been set before.

        Args:
            labels: Scalar fill value used for every entry of ``self.y``.
        """
        # ``self.y`` does not exist until the first call, so fall back to
        # None instead of failing with AttributeError.
        self.y_old = getattr(self, 'y', None)
        self.y = torch.full(size=(len(self),), fill_value=labels, device=self.device)

    def _load_features(self, feature_path):
        """Load embeddings and per-item speaker information.

        Side effect: sets ``self.mean`` and ``self.std`` from the loaded
        vectors.

        Returns:
            A pair ``(vectors, speakers)``. For a single file, ``speakers`` is
            a zero tensor with one entry per vector. For a directory,
            ``speakers`` holds the ids from the index file, translated through
            ``utt2spk`` when that file exists.
        """
        if os.path.isfile(feature_path):
            vectors = torch.load(feature_path, map_location=self.device)
            if isinstance(vectors, list):
                vectors = torch.stack(vectors)
            self._fit_normalization(vectors)
            return vectors, torch.zeros(vectors.size(0))

        # NOTE(review): this passes the directory path itself to torch.load;
        # the name of the tensor file inside the directory is not visible
        # here -- confirm the intended file and adjust.
        vectors = torch.load(feature_path, map_location=self.device)
        self._fit_normalization(vectors)

        id2idx = self._read_two_column_file(feature_path / f'{self.mode}2idx', int)
        speakers, indices = zip(*id2idx.items())

        utt2spk_path = feature_path / 'utt2spk'
        if utt2spk_path.exists():
            # The index file is keyed by utterance ids here, not speaker ids.
            utt2spk = self._read_two_column_file(utt2spk_path, str)
            speakers = [utt2spk[utt] for utt in speakers]

        return vectors[np.array(indices)], speakers

    def _fit_normalization(self, vectors):
        """Store the scalar mean and std computed over all of ``vectors``."""
        self.mean = torch.mean(vectors)
        self.std = torch.std(vectors)

    @staticmethod
    def _read_two_column_file(path, value_type):
        """Parse a whitespace-separated ``key value`` file into a dict.

        Lines that do not consist of exactly two fields are ignored.
        """
        mapping = {}
        with open(path, 'r') as f:
            for line in f:
                fields = line.split()
                if len(fields) == 2:
                    mapping[fields[0]] = value_type(fields[1])
        return mapping

    def _reformat_features(self, features):
        """Reshape 2-D ``(a, b)`` features to ``(a, 1, 1, b)``.

        Returns ``None`` for input of any other rank.
        """
        if len(features.shape) == 2:
            return features.reshape(features.shape[0], 1, 1, features.shape[1])
        return None