import os

import numpy as np
import torch
import torchaudio
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification


class CustomDataset(torch.utils.data.Dataset):
    """Loads audio files, resamples them to a target rate, and truncates them to a maximum length."""

    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len  # maximum clip length in seconds

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        # Truncate clips longer than max_audio_len; shorter clips are left as-is,
        # since padding is applied batch-wise by the collate function.
        effective_length = self.sampling_rate * self.max_audio_len
        if len(audio) > effective_length:
            audio = audio[:effective_length]
        return audio

    def __getitem__(self, index):
        if self.basedir is None:
            filepath = self.dataset[index]
        else:
            filepath = os.path.join(self.basedir, self.dataset[index])

        speech_array, sr = torchaudio.load(filepath)

        # Downmix multi-channel audio to mono.
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

        # Resample to the target sampling rate if needed.
        if sr != self.sampling_rate:
            transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = transform(speech_array)

        speech_array = speech_array.squeeze().numpy()
        speech_array = self._cutorpad(speech_array)

        return {"input_values": speech_array}


class CollateFunc:
    """Runs the feature extractor on each example and pads the batch to a common length."""

    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.padding = padding
        self.processor = processor
        self.max_length = max_length
        self.sampling_rate = sampling_rate
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, batch):
        input_features = []
        for audio in batch:
            input_tensor = self.processor(audio["input_values"], sampling_rate=self.sampling_rate).input_values
            input_tensor = np.squeeze(input_tensor)
            input_features.append({"input_values": input_tensor})

        # Pad all examples in the batch to the same length and return PyTorch tensors.
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        return batch


def predict(test_dataloader, model, device):
    """Runs inference over the dataloader and returns the predicted class index for each example."""
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch["input_values"].to(device)
            logits = model(input_values).logits
            scores = F.softmax(logits, dim=-1)
            pred = torch.argmax(scores, dim=1).cpu().numpy()
            preds.extend(pred)

    return preds


def get_gender(model_name_or_path, audio_paths, device):
    """Classifies a list of audio files and returns the most common predicted gender label."""
    num_labels = 2

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )

    test_dataset = CustomDataset(audio_paths)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=16000,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=16,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=10,
    )

    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # Map class indices to labels.
    label_mapping = {0: "female", 1: "male"}

    # Return the most common predicted label across all audio files.
    most_common_label = max(set(preds), key=preds.count)
    predicted_label = label_mapping[most_common_label]

    return predicted_label
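

# Minimal usage sketch. The checkpoint name and audio paths below are placeholders,
# not values taken from this script; substitute any audio-classification checkpoint
# whose two labels match the {0: "female", 1: "male"} mapping used above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical checkpoint and files, shown only for illustration.
    model_name_or_path = "your-org/your-gender-classification-checkpoint"
    audio_paths = ["sample_1.wav", "sample_2.wav"]

    label = get_gender(model_name_or_path, audio_paths, device)
    print(f"Predicted gender: {label}")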