# Fast_api/vers/find_valence.py
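# NOTE: The wav2vec2-based VAD (valence/arousal/dominance) implementation below is
# kept for reference but commented out; get_valence_score() currently returns a
# fixed placeholder score.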
# from transformers.models.wav2vec2 import Wav2Vec2Model, Wav2Vec2FeatureExtractor
# import torchaudio
# import torch
# import torch.nn as nn
def get_valence_score(file_path):
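    """Return a valence score for the audio file at `file_path`.

    The wav2vec2-based VAD pipeline below is commented out; until it is re-enabled,
    the function returns a fixed placeholder value.
    """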
# class VADPredictor(nn.Module):
# """Model to predict VAD Scores"""
# def __init__(self, pretrained_model_name="facebook/wav2vec2-base-960h", freeze_feature_extractor=True):
# super(VADPredictor, self).__init__()
# self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
# if freeze_feature_extractor:
# for param in self.wav2vec2.feature_extractor.parameters():
# param.requires_grad = False
# hidden_size = self.wav2vec2.config.hidden_size
# self.valence_layers = nn.Sequential(
# nn.Linear(hidden_size, 256),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(256,64),
# nn.Linear(64,1)
# )
# self.arousal_layers = nn.Sequential(
# nn.Linear(hidden_size, 256),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(256,64),
# nn.Linear(64,1)
# )
# self.dominance_layers = nn.Sequential(
# nn.Linear(hidden_size, 256),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(256,64),
# nn.Linear(64,1)
# )
# def forward(self, input_values, attention_mask=None):
# outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
# last_hidden_state = outputs.last_hidden_state
# pooled_output = torch.mean(last_hidden_state, dim=1)
# valence = self.valence_layers(pooled_output)
# arousal = self.arousal_layers(pooled_output)
# dominance = self.dominance_layers(pooled_output)
# return {
# 'valence': valence.squeeze(-1),
# 'arousal': arousal.squeeze(-1),
# 'dominance': dominance.squeeze(-1)
# }
# model = VADPredictor()
# model.load_state_dict(torch.load(r"D:\Intern\shankh\DUMP\vad_predictor_model.pt", map_location=torch.device("cpu")))
# model.eval()
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
#     # Load the audio file
#     waveform, sr = torchaudio.load(file_path)
# # Convert to mono
# if waveform.shape[0] > 1:
# waveform = waveform.mean(dim=0, keepdim=True)
# # Resample to 16000 Hz
# if sr != 16000:
# resampler = torchaudio.transforms.Resample(sr, 16000)
# waveform = resampler(waveform)
# sr = 16000
#     # Peak-normalize the waveform (assumes the clip is not completely silent)
#     waveform = waveform / waveform.abs().max()
# # Parameters
# segment_sec = 1
# segment_samples = int(segment_sec * sr)
# valence_scores = []
# # Inference per segment
# with torch.no_grad():
# for start in range(0, waveform.shape[1] - segment_samples + 1, segment_samples):
# segment = waveform[:, start:start+segment_samples]
# input_values = feature_extractor(segment.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values
# output = model(input_values)
# val = output['valence'].item()
# valence_scores.append(val)
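#     # One plausible way to collapse the per-segment scores into a single value
#     # (an assumption; the original file does not show this aggregation step):
#     # valence_score = sum(valence_scores) / len(valence_scores)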
    # Inference is disabled above, so return a fixed neutral placeholder score.
    valence_score = 5.0
    return valence_score
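
# Minimal usage sketch ("sample.wav" is a hypothetical path, not part of the original
# file). With the model code disabled above, this simply prints the placeholder 5.0.
if __name__ == "__main__":
    print(get_valence_score("sample.wav"))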