colab_test_model / handler.py
gautamtata's picture
Update handler.py
5d7fda2
raw
history blame
6.13 kB
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torchaudio
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoConfig, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model,
)
from transformers.utils import ModelOutput
class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head applied to a pooled wav2vec2 representation.

    The head is dropout -> linear -> tanh -> dropout -> linear, with the
    final linear layer producing ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        """Build the layers from ``config``.

        Reads ``config.hidden_size``, ``config.final_dropout`` and
        ``config.num_labels``.
        """
        super().__init__()
        width = config.hidden_size
        # Attribute names (and creation order) are kept stable: they define
        # the state-dict keys and the order in which weights are initialised.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return the head's output for ``features``.

        Extra keyword arguments are accepted and ignored.
        """
        projected = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)
@dataclass
class SpeechClassifierOutput(ModelOutput):
    """Output container returned by ``Wav2Vec2ForSpeechClassification.forward``.

    Every field defaults to ``None``; ``forward`` fills in whichever values
    it has available.
    """

    # Training loss; only set when labels are passed to forward().
    loss: Optional[torch.FloatTensor] = None
    # Raw (unnormalised) classifier outputs.
    logits: Optional[torch.FloatTensor] = None
    # Forwarded from the base wav2vec2 model's output.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Forwarded from the base wav2vec2 model's output.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    """wav2vec2 encoder followed by pooling and a classification head."""

    def __init__(self, config):
        """Create the encoder and the head from ``config``.

        Besides the standard wav2vec2 settings, ``config`` must provide
        ``num_labels`` and ``pooling_mode``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        """Freeze the parameters of the encoder's feature extractor."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(
        self,
        hidden_states,
        mode="mean"
    ):
        """Collapse dimension 1 of ``hidden_states`` using ``mode``.

        Args:
            hidden_states: Tensor to pool; dimension 1 is reduced
                (presumably the time axis of the encoder output).
            mode: One of ``"mean"``, ``"sum"`` or ``"max"``.

        Returns:
            The pooled tensor with dimension 1 removed.

        Raises:
            ValueError: If ``mode`` is not a supported pooling mode.
        """
        if mode == "mean":
            return torch.mean(hidden_states, dim=1)
        if mode == "sum":
            return torch.sum(hidden_states, dim=1)
        if mode == "max":
            # torch.max over a dim returns (values, indices); keep the values.
            return torch.max(hidden_states, dim=1)[0]
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError(
            "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        """Run the encoder, pool its output and classify it.

        Args:
            input_values: Input passed straight to the wav2vec2 encoder.
            attention_mask: Optional mask forwarded to the encoder.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            return_dict: When ``None``, falls back to
                ``config.use_return_dict``.
            labels: Optional targets; when given, a loss is computed.

        Returns:
            A ``SpeechClassifierOutput`` when ``return_dict`` is truthy,
            otherwise a tuple ``(loss?, logits, *encoder_extras)`` where the
            loss is only present if ``labels`` was provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first encoder output is pooled into one vector per example.
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Skip the first two entries of the encoder tuple and append
            # whatever extras follow them.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Inference endpoint handler: loads the model and processor once, then serves predictions.
class EndpointHandler:
    """Callable handler that classifies an audio file given its path."""

    def __init__(self, model_path=""):
        """Load the config, processor and model found under ``model_path``."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = AutoConfig.from_pretrained(f"{model_path}/config.json")
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.model = Wav2Vec2ForSpeechClassification.from_pretrained(model_path).to(self.device)
        # Inference only: make sure dropout layers are disabled.
        self.model.eval()

    @staticmethod
    def _downmix_to_mono(waveform):
        """Average a ``(channels, time)`` tensor into a 1-D ``(time,)`` tensor.

        Tensors that are already 1-D are returned unchanged.
        """
        if waveform.dim() > 1:
            return waveform.mean(dim=0)
        return waveform

    def speech_file_to_array_fn(self, path):
        """Load ``path`` and return a mono array at the processor's sampling rate.

        Multi-channel audio is averaged to a single channel so the processor
        always receives one 1-D utterance.
        """
        target_rate = self.processor.feature_extractor.sampling_rate
        waveform, source_rate = torchaudio.load(path)
        waveform = self._downmix_to_mono(waveform)
        if source_rate != target_rate:
            resampler = torchaudio.transforms.Resample(source_rate, target_rate)
            waveform = resampler(waveform)
        return waveform.numpy()

    def predict(self, path):
        """Classify the audio file at ``path``.

        Returns:
            One ``{"label": ..., "score": ...}`` dict per class, in label-id
            order, where ``score`` is the softmax probability as a plain
            Python float.
        """
        speech = self.speech_file_to_array_fn(path)
        sampling_rate = self.processor.feature_extractor.sampling_rate
        features = self.processor(speech, sampling_rate=sampling_rate,
                                  return_tensors="pt", padding=True)
        input_values = features["input_values"].to(self.device)
        # Not every processor configuration emits an attention mask.
        attention_mask = features.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits

        # tolist() yields built-in floats, which serialise to JSON directly.
        scores = F.softmax(logits, dim=1)[0].tolist()
        return [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """Run inference for a request.

        Args:
            data: Request payload; must contain a ``"path"`` entry pointing
                to the audio file.

        Returns:
            The list produced by ``predict``, or an ``{"error": ...}`` dict
            when no path was supplied.
        """
        path = data.get("path")
        if not path:
            return {"error": "Path to the audio file is required."}
        return self.predict(path)