from torch.utils.data import Dataset
from tqdm.auto import tqdm
import os
import librosa
import numpy as np
import torch
import random
from numpy.linalg import norm
from utils.VAD_segments import VAD_chunk
from utils.hparam import hparam as hp


class GujaratiSpeakerVerificationDatasetTest(Dataset):
    def __init__(self, path, shuffle=True, utter_start=0):
        # data path
        self.path = path
        self.file_list = os.listdir(self.path)
        self.shuffle = shuffle
        self.utter_start = utter_start
        self.utter_num = 4

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        selected_file = self.file_list[idx]
        # load utterance spectrograms of the selected speaker
        utters = np.load(os.path.join(self.path, selected_file))
        if self.shuffle:
            # randomly select M utterances per speaker
            utter_index = np.random.randint(0, utters.shape[0], self.utter_num)
            utterance = utters[utter_index]
        else:
            # take M consecutive utterances starting at utter_start [M, n_mels, frames]
            utterance = utters[self.utter_start:self.utter_start + self.utter_num]
        utterance = utterance[:, :, :160]  # TODO: support variable-length batches
        # transpose to [M, frames, n_mels]
        utterance = torch.tensor(np.transpose(utterance, axes=(0, 2, 1)))
        return utterance
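
# Usage sketch (illustrative only): the directory and batch size below are
# placeholders, not values defined in this repo. Each dataset item is a tensor
# of shape [M=4, frames=160, n_mels] loaded from one speaker's .npy file.
#
#   from torch.utils.data import DataLoader
#   test_db = GujaratiSpeakerVerificationDatasetTest("data/test_tisv", shuffle=False)
#   loader = DataLoader(test_db, batch_size=2, shuffle=False, drop_last=True)
#   for batch in loader:
#       print(batch.shape)  # [2, 4, 160, n_mels]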


def concat_segs(times, segs):
    # Merge VAD segments whose time stamps are contiguous into longer segments.
    concat_seg = []
    seg_concat = segs[0]
    for i in range(0, len(times) - 1):
        if times[i][1] == times[i + 1][0]:
            # adjacent segments: append to the running buffer
            seg_concat = np.concatenate((seg_concat, segs[i + 1]))
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i + 1]
    else:
        # flush the last buffered segment once the loop finishes
        concat_seg.append(seg_concat)
    return concat_seg


def get_STFTs(segs):
    # Convert concatenated VAD segments into fixed-width (24-frame) log-mel windows.
    sr = 16000
    STFT_frames = []
    for seg in segs:
        S = librosa.stft(y=seg, n_fft=hp.data.nfft,
                         win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
        S = np.abs(S) ** 2  # power spectrogram
        mel_basis = librosa.filters.mel(sr=sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
        S = np.log10(np.dot(mel_basis, S) + 1e-6)  # log-mel spectrogram
        # slide a 24-frame window with a stride of 120 ms worth of hops
        for j in range(0, S.shape[1], int(0.12 / hp.data.hop)):
            if j + 24 < S.shape[1]:
                STFT_frames.append(S[:, j:j + 24])
            else:
                break
    return STFT_frames


def get_embedding(file_path, embedder_net, device, n_threshold=-1):
    # Run VAD on one file, build log-mel windows, and average the network embeddings.
    times, segs = VAD_chunk(2, file_path)
    if not segs:
        print(f'No voice activity detected in {file_path}')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print(f'No concatenated segments for {file_path}')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        # print(f'No STFT frames for {file_path}')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
        embeddings = embeddings[:n_threshold, :]
        avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding
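
# Usage sketch (illustrative only): `embedder_net` is assumed to be a trained
# speaker-embedding model already moved to `device`; the path is a placeholder.
#
#   emb = get_embedding("data/speakers/spk_01/utt_001.wav", embedder_net, device)
#   if emb is not None:
#       print(emb.shape)  # (1, embedding_dim)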


def get_speaker_embeddings_listdir(embedder_net, device, list_dir, k):
    # For every speaker directory, build 10 averaged embeddings, each from a
    # random draw of utterances (the inner loop keeps going until k + 1 usable
    # files have been embedded or the file list is exhausted).
    speaker_embeddings = {}
    for speaker_name in tqdm(list_dir, leave=False):
        speaker_dir = speaker_name
        if os.path.isdir(speaker_dir) and os.path.basename(speaker_dir) != ".DS_Store":
            speaker_embeddings[speaker_name] = []
            for i in range(10):
                embeddings = []
                audio_files = [os.path.join(speaker_dir, f) for f in os.listdir(speaker_dir) if f.endswith('.wav')]
                random.shuffle(audio_files)
                count = 0
                iter_ = 0
                while count <= k and iter_ < len(audio_files):
                    file_path = audio_files[iter_]
                    embedding = get_embedding(file_path, embedder_net, device)
                    iter_ += 1
                    if embedding is not None:
                        embeddings.append(embedding)
                        count += 1
                speaker_embeddings[speaker_name].append(np.mean(embeddings, axis=0))
    return speaker_embeddings


def create_pairs(speaker_embeddings):
    # Build all ordered pairs of the 10 averaged embeddings per speaker.
    # A pair is labelled 1 only for two *different* draws from the same speaker;
    # identical draws and cross-speaker pairs are labelled 0.
    pairs = []
    labels = []
    speakers = list(speaker_embeddings.keys())
    for i in range(len(speakers)):
        for j in range(len(speakers)):
            for k1 in range(10):
                for k2 in range(10):
                    emb1 = speaker_embeddings[speakers[i]][k1]
                    emb2 = speaker_embeddings[speakers[j]][k2]
                    pairs.append((emb1, emb2))
                    if i == j and not (emb1 == emb2).all():
                        labels.append(1)  # same speaker
                    else:
                        labels.append(0)  # different speakers (or identical draw)
    return pairs, labels


class EmbeddingPairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        emb1, emb2 = self.pairs[idx]
        label = self.labels[idx]
        emb1, emb2 = torch.tensor(emb1, dtype=torch.float32), torch.tensor(emb2, dtype=torch.float32)
        # each embedding is [1, embedding_dim]; concatenate along the feature axis
        concatenated = torch.cat((emb1, emb2), dim=1)
        return concatenated.squeeze(), torch.tensor(label, dtype=torch.float32)

    def __repr__(self):
        return f"{self.__class__.__name__}(length={self.__len__()})"


def cosine_similarity(A, B):
    # Cosine similarity between two flattened embedding vectors.
    A = A.flatten().astype(np.float64)
    B = B.flatten().astype(np.float64)
    return np.dot(A, B) / (norm(A) * norm(B))


def create_subset(dataset, num_zeros):
    # Keep every positive pair and at most num_zeros negative pairs.
    pairs = dataset.pairs
    labels = dataset.labels
    pairs_1 = [pairs[i] for i in range(len(pairs)) if labels[i] == 1]
    labels_1 = [labels[i] for i in range(len(labels)) if labels[i] == 1]
    pairs_0 = [pairs[i] for i in range(len(pairs)) if labels[i] == 0]
    labels_0 = [labels[i] for i in range(len(labels)) if labels[i] == 0]
    num_zeros = min(num_zeros, len(pairs_0))
    pairs_0 = pairs_0[:num_zeros]
    labels_0 = labels_0[:num_zeros]
    filtered_pairs = pairs_1 + pairs_0
    filtered_labels = labels_1 + labels_0
    return filtered_pairs, filtered_labels
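

# Minimal end-to-end sketch, guarded so it only runs when this module is executed
# directly. The SpeechEmbedder import, checkpoint path, and data directory are
# assumptions about the surrounding project, not guaranteed by this file.
if __name__ == "__main__":
    from speech_embedder_net import SpeechEmbedder  # assumed model module

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embedder_net = SpeechEmbedder().to(device)
    embedder_net.load_state_dict(torch.load("speech_id_checkpoint/final_epoch.model", map_location=device))
    embedder_net.eval()

    root = "data/test_speakers"  # one sub-directory of .wav files per speaker
    list_dir = [os.path.join(root, d) for d in os.listdir(root)]
    speaker_embeddings = get_speaker_embeddings_listdir(embedder_net, device, list_dir, k=4)

    pairs, labels = create_pairs(speaker_embeddings)
    dataset = EmbeddingPairDataset(pairs, labels)
    # balance the subset: keep as many negative pairs as there are positives
    subset_pairs, subset_labels = create_subset(dataset, num_zeros=sum(labels))
    print(dataset, len(subset_pairs), len(subset_labels))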