from transformers import AutoConfig, Wav2Vec2Processor
from transformers.file_utils import ModelOutput
from dataclasses import dataclass
from torch import nn
import torch
import io
import torchaudio
import torch.nn.functional as F
from typing import Dict, List, Any, Optional, Tuple
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import requests
import tempfile
import os
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model,
)


@dataclass
class SpeechClassifierOutput(ModelOutput):
    """Output of the classification model: loss, logits, and optional hidden states/attentions."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # Dropout -> dense -> tanh -> dropout -> projection to num_labels.
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    """wav2vec 2.0 encoder with a pooling step and a classification head on top."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)
        self.init_weights()

    def freeze_feature_extractor(self):
        # Keep the convolutional feature extractor frozen during fine-tuning.
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(self, hidden_states, mode="mean"):
        # Pool the frame-level hidden states (batch, time, hidden) into a single
        # utterance-level vector (batch, hidden).
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                f"Undefined pooling mode {mode!r}; pooling_mode must be one of ['mean', 'sum', 'max']."
            )
        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # Infer the problem type from num_labels and the label dtype, mirroring the
            # convention used by other Hugging Face sequence classification heads.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class EndpointHandler():
    """Custom Inference Endpoints handler wrapping the predict logic of the model above."""

    def __init__(self, model_path=""):
        # Load the config, processor, and fine-tuned model onto the available device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = AutoConfig.from_pretrained(f"{model_path}/config.json")
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-libri-960h")
        self.model = Wav2Vec2ForSpeechClassification.from_pretrained(model_path).to(self.device)

    def speech_file_to_array_fn(self, path):
        # Load the audio and resample it to the rate the processor expects.
        sampling_rate = self.processor.feature_extractor.sampling_rate
        speech_array, _sampling_rate = torchaudio.load(path)
        resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
        speech = resampler(speech_array).squeeze().numpy()
        return speech

    def predict(self, path):
        # `path` may be a filesystem path or a file-like object; torchaudio.load accepts both.
        speech = self.speech_file_to_array_fn(path)
        features = self.processor(
            speech,
            sampling_rate=self.processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
        outputs = [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]
        return outputs

    def download_file(self, url):
        """Download the file at `url` and return the path to a temporary .wav copy, or None on failure."""
        response = requests.get(url)
        if response.status_code == 200:
            # Write the payload to a named temporary file so it can be loaded by path.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            temp_file.write(response.content)
            temp_file.close()
            return temp_file.name
        else:
            return None

    def __call__(self, request: Dict[str, Any]) -> List[Dict[str, Any]]:
        # The raw audio bytes are expected under the "inputs" key of the request payload.
        audio_data = request.get("inputs")
        if audio_data:
            # Wrap the bytes in a buffer; torchaudio.load (used inside predict via
            # speech_file_to_array_fn) can read from a file-like object directly.
            audio_buffer = io.BytesIO(audio_data)
            predictions = self.predict(audio_buffer)
            return predictions
        else:
            return {"error": "Audio input is required."}