gautamtata committed
Commit 38c2b04 · 1 Parent(s): d415c99

Create handler.py

Files changed (1)
  1. handler.py +55 -0
handler.py ADDED
@@ -0,0 +1,55 @@
+ import torchaudio
+ import torch
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, AutoConfig
+ from torch.nn.functional import softmax
+ from typing import Dict, List, Any
+
+ # Handler for a wav2vec2-based speech classification model.
+ class EndpointHandler():
+     def __init__(self, path="."):
+         # The path is expected to contain all files for the config, processor, and model.
+         self.config = AutoConfig.from_pretrained(path)
+         self.processor = Wav2Vec2Processor.from_pretrained(path)
+         self.model = Wav2Vec2ForSequenceClassification.from_pretrained(path)
+         self.sampling_rate = self.processor.feature_extractor.sampling_rate
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.model.to(self.device)
+         self.model.eval()
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Handle speech input and return the classification result.
+         """
+         # Extract the 'inputs' key from the data dictionary; it should be a path to the audio file.
+         audio_path = data.get('inputs', None)
+         if audio_path is None:
+             raise ValueError("Invalid input: an 'inputs' key with the path to the audio file is required.")
+
+         # Load and preprocess the audio file, then run prediction.
+         return self.predict(audio_path)
+
+     def predict(self, path):
+         """
+         Run prediction on the provided audio file path.
+         """
+         # Load the audio file.
+         speech_array, original_sampling_rate = torchaudio.load(path)
+         # Resample if necessary.
+         if original_sampling_rate != self.sampling_rate:
+             resampler = torchaudio.transforms.Resample(original_sampling_rate, self.sampling_rate)
+             speech_array = resampler(speech_array)
+         speech_array = speech_array.squeeze().numpy()
+
+         # Preprocess the audio input.
+         inputs = self.processor(speech_array, sampling_rate=self.sampling_rate, return_tensors="pt", padding=True)
+         input_values = inputs.input_values.to(self.device)
+         attention_mask = inputs.attention_mask.to(self.device)
+
+         # Model inference.
+         with torch.no_grad():
+             logits = self.model(input_values, attention_mask=attention_mask).logits
+
+         # Postprocessing: convert logits to per-label probabilities.
+         scores = softmax(logits, dim=1).detach().cpu().numpy()[0]
+         predictions = [{"label": self.config.id2label[i], "score": float(score)} for i, score in enumerate(scores)]
+         return predictions
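
For reference, a minimal local smoke test of this handler might look like the sketch below. It assumes the model, processor, and config files are in the current working directory and that an audio file named sample.wav exists; both the path "." and "sample.wav" are placeholder assumptions, not part of the commit.

# Hypothetical local invocation of the handler (paths are placeholders).
from handler import EndpointHandler

handler = EndpointHandler(path=".")
result = handler({"inputs": "sample.wav"})
print(result)  # e.g. [{"label": "class_0", "score": 0.91}, {"label": "class_1", "score": 0.09}]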