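# gender_detection/gender_prediction.py
# Last commit 7d66980 ("Create gender_prediction.py"), shown next to user
# Salman11223; file size 3.94 kB.
# NOTE: this header was captured from a web file-viewer page; its
# "raw / history / blame / contribute / delete" controls and "No virus" scan
# badge were page chrome, not part of the source.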
import os
import tqdm
import torch
import torchaudio
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Wav2Vec2Processor
from torch.nn import functional as F
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads audio files as mono NumPy waveforms.

    Each item is loaded with ``torchaudio.load``, down-mixed to mono by
    averaging channels, resampled to ``sampling_rate`` if needed, and
    truncated to at most ``max_audio_len`` seconds.

    Args:
        dataset: Indexable, sized sequence of audio file paths.
        basedir: Optional directory joined in front of every entry of ``dataset``.
        sampling_rate: Target sampling rate in Hz.
        max_audio_len: Maximum clip length in seconds (may be fractional);
            ``None`` disables truncation.
    """

    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        """Return ``audio`` truncated to at most ``max_audio_len`` seconds.

        Despite the name, nothing is padded here: shorter clips are returned
        unchanged and batch padding is left to the collate function.
        """
        if self.max_audio_len is None:
            return audio
        # int() keeps the slice bound valid when max_audio_len is fractional.
        effective_length = int(self.sampling_rate * self.max_audio_len)
        if len(audio) > effective_length:
            audio = audio[:effective_length]
        return audio

    def __getitem__(self, index):
        """Load item ``index`` as ``{"input_values": waveform, "attention_mask": None}``.

        ``waveform`` is a 1-D NumPy array sampled at ``self.sampling_rate``.
        """
        if self.basedir is None:
            filepath = self.dataset[index]
        else:
            filepath = os.path.join(self.basedir, self.dataset[index])
        # torchaudio.load returns a (channels, frames) tensor and the file's native rate.
        speech_array, sr = torchaudio.load(filepath)
        if speech_array.shape[0] > 1:
            # Down-mix multi-channel audio to mono, keeping the channel axis.
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = resampler(speech_array)
        # Index away only the channel axis: squeeze() would also collapse a
        # 1-frame clip into a 0-d array, which _cutorpad's len() cannot handle.
        speech_array = speech_array[0].numpy()
        speech_array = self._cutorpad(speech_array)
        return {"input_values": speech_array, "attention_mask": None}
class CollateFunc:
    """Collate raw waveforms into a padded PyTorch batch via a feature extractor.

    Args:
        processor: Feature extractor (or processor) exposing ``__call__`` that
            returns an object with ``.input_values``, and a ``pad`` method.
        max_length: Forwarded to ``processor.pad``.
        padding: Padding strategy forwarded to ``processor.pad``.
        pad_to_multiple_of: Forwarded to ``processor.pad``.
        sampling_rate: Sampling rate (Hz) of the incoming waveforms, passed to
            the processor on every call.
    """

    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.padding = padding
        self.processor = processor
        self.max_length = max_length
        self.sampling_rate = sampling_rate
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, batch):
        """Featurize each item's ``"input_values"`` and pad them together.

        Args:
            batch: List of dicts with an ``"input_values"`` waveform, as
                produced by ``CustomDataset.__getitem__``.

        Returns:
            The result of ``processor.pad(..., return_tensors="pt")``.
        """
        input_features = []
        for item in batch:
            features = self.processor(item["input_values"], sampling_rate=self.sampling_rate).input_values
            # A single unbatched waveform comes back as a batch of one; take
            # that element rather than np.squeeze, which would also collapse a
            # length-1 waveform into a 0-d scalar.
            input_features.append({"input_values": features[0]})
        return self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
def predict(test_dataloader, model, device):
    """Run ``model`` over every batch of ``test_dataloader`` and collect class indices.

    The model is moved to ``device`` and switched to eval mode (both mutate
    ``model`` in place).

    Args:
        test_dataloader: Iterable of dict-like batches with an
            ``"input_values"`` tensor and, optionally, an ``"attention_mask"``.
        model: Classifier whose output exposes ``.logits`` of shape
            (batch, num_labels).
        device: Torch device (or device string) used for inference.

    Returns:
        list[int]: Predicted class index per example, in dataloader order.
    """
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch["input_values"].to(device)
            # The collator's pad() only emits an attention mask when the
            # feature extractor is configured to return one; in that case the
            # model needs it to ignore the zero padding of shorter clips.
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                logits = model(input_values, attention_mask=attention_mask.to(device)).logits
            else:
                logits = model(input_values).logits
            # softmax is monotonic, so argmax over raw logits gives the same class.
            preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return preds
def get_gender(model_name_or_path, audio_paths, device, *, batch_size=16, num_workers=10):
    """Predict one gender label for a collection of audio clips by majority vote.

    Args:
        model_name_or_path: Model id or local path of a 2-class
            audio-classification checkpoint loadable by ``from_pretrained``.
        audio_paths: Sized, indexable sequence of audio file paths.
        device: Torch device (or device string) used for inference.
        batch_size: Number of clips per inference batch.
        num_workers: DataLoader worker processes used to load audio.

    Returns:
        str: ``"female"`` or ``"male"``, the most frequent per-clip prediction.
        On a tie the lower class index (``"female"``) wins.

    Raises:
        ValueError: If ``audio_paths`` is empty.
    """
    from collections import Counter

    if len(audio_paths) == 0:
        # Fail fast instead of loading the model only to take max() of nothing.
        raise ValueError("audio_paths must contain at least one audio file")

    num_labels = 2
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )
    # Resample to, and featurize at, the rate the checkpoint's extractor
    # declares (e.g. Wav2Vec2FeatureExtractor raises on a mismatch); fall back
    # to 16 kHz when the extractor does not expose one.
    sampling_rate = getattr(feature_extractor, "sampling_rate", 16000)
    test_dataset = CustomDataset(audio_paths, sampling_rate=sampling_rate)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=sampling_rate,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=num_workers,
    )
    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # Map class indices to labels.
    label_mapping = {0: "female", 1: "male"}
    counts = Counter(preds)
    # Majority vote: max() returns the first maximal key, and keys are sorted
    # ascending, so ties resolve to the lowest class index.
    most_common_label = max(sorted(counts), key=counts.get)
    return label_mapping[most_common_label]