gautamtata committed on
Commit
c645001
1 Parent(s): 96fd59e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -45
handler.py CHANGED
@@ -1,55 +1,50 @@
1
- import torchaudio
 
2
  import torch
3
- from transformers import Wav2Vec2Processor, Wav2Vec2ForSpeechClassification, AutoConfig
4
- from torch.nn.functional import softmax
5
  from typing import Dict, List, Any
6
 
7
- # Suppose this handler is for a speech classification model
 
8
  class EndpointHandler():
9
- def __init__(self, path=""):
10
- # Assuming that the path contains all the necessary files for model and processor.
11
- config = AutoConfig.from_pretrained(path)
12
- self.processor = Wav2Vec2Processor.from_pretrained(path)
13
- self.model = Wav2Vec2ForSpeechClassification.from_pretrained(path)
14
- self.sampling_rate = self.processor.feature_extractor.sampling_rate
15
- self.model.to('cuda' if torch.cuda.is_available() else 'cpu')
16
 
17
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
18
- """
19
- Overriding call method to handle speech input and return classification result.
20
- """
21
- # Extract 'inputs' key from the data dictionary. This should be a path to the audio file.
22
- audio_path = data.get('inputs', None)
23
- if audio_path is None:
24
- raise ValueError("Invalid input, 'inputs' key with path to the audio file is required.")
25
-
26
- # Load and preprocess the audio file, and run prediction
27
- outputs = self.predict(audio_path)
28
- return outputs
29
 
30
  def predict(self, path):
31
- """
32
- Runs prediction on the provided audio file path.
33
- """
34
- # Load audio file
35
- speech_array, _sampling_rate = torchaudio.load(path)
36
- # Resample if necessary
37
- if _sampling_rate != self.sampling_rate:
38
- resampler = torchaudio.transforms.Resample(_sampling_rate, self.sampling_rate)
39
- speech_array = resampler(speech_array)
40
- speech_array = speech_array.squeeze().numpy()
41
-
42
- # Preprocess audio input
43
- inputs = self.processor(speech_array, sampling_rate=self.sampling_rate, return_tensors="pt", padding=True)
44
- input_values = inputs.input_values.to('cuda' if torch.cuda.is_available() else 'cpu')
45
- attention_mask = inputs.attention_mask.to('cuda' if torch.cuda.is_available() else 'cpu')
46
-
47
- # Model inference
48
  with torch.no_grad():
49
  logits = self.model(input_values, attention_mask=attention_mask).logits
50
-
51
- # Postprocessing
52
- scores = softmax(logits, dim=1).detach().cpu().numpy()[0]
53
- predictions = [{"label": self.config.id2label[i], "score": float(score)} for i, score in enumerate(scores)]
54
- return predictions
 
 
 
 
 
55
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, Wav2Vec2Processor, Wav2Vec2ForSpeechClassification
2
+ from torch import nn
3
  import torch
4
+ import torchaudio
5
+ import torch.nn.functional as F
6
  from typing import Dict, List, Any
7
 
8
+ # Assuming the provided predict and related functions are part of your handler
9
+
10
class EndpointHandler:
    """Inference handler that classifies an audio file with a Wav2Vec2 model.

    The handler loads a config, a processor and a classification model once
    at construction time and is then invoked per request through
    ``__call__`` with a dictionary that carries the audio file path under
    the ``"path"`` key.
    """

    def __init__(self, model_path=""):
        """Load the config, processor and model found at ``model_path``.

        Args:
            model_path: Location passed to every ``from_pretrained`` call.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Hand over the model location itself instead of building
        # "<model_path>/config.json" by hand: the string concatenation breaks
        # for anything that is not a local directory and turns the default
        # empty path into "/config.json".
        self.config = AutoConfig.from_pretrained(model_path)
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        # NOTE(review): Wav2Vec2ForSpeechClassification is imported from
        # transformers at the top of the file - confirm the installed
        # transformers build actually exports a class with that name.
        self.model = Wav2Vec2ForSpeechClassification.from_pretrained(model_path).to(self.device)

    @staticmethod
    def _to_mono(waveform):
        """Average a ``(channels, time)`` tensor into a 1-D ``(time,)`` tensor.

        For single-channel audio this is the same result ``squeeze()`` gave;
        for multi-channel audio it avoids handing the processor one row per
        channel, which it would treat as a batch of separate clips.
        """
        return waveform.mean(dim=0)

    def _format_scores(self, logits):
        """Turn a ``(1, num_labels)`` logits tensor into label/score dicts.

        Scores are converted to built-in ``float`` so the result can be
        serialised by the standard JSON encoder (``numpy.float32`` cannot).
        """
        probabilities = F.softmax(logits, dim=1)[0].detach().cpu().tolist()
        return [
            {"label": self.config.id2label[index], "score": probability}
            for index, probability in enumerate(probabilities)
        ]

    def speech_file_to_array_fn(self, path):
        """Load the audio file at ``path`` as a 1-D NumPy array.

        The waveform is resampled to the processor's sampling rate when the
        file uses a different one, then downmixed to a single channel.
        """
        target_rate = self.processor.feature_extractor.sampling_rate
        waveform, source_rate = torchaudio.load(path)
        if source_rate != target_rate:
            # Only build the resampler when the rates actually differ.
            waveform = torchaudio.transforms.Resample(source_rate, target_rate)(waveform)
        return self._to_mono(waveform).numpy()

    def predict(self, path):
        """Classify the audio file at ``path``.

        Returns:
            One ``{"label": str, "score": float}`` dict per entry of
            ``config.id2label``, in label-index order.
        """
        speech = self.speech_file_to_array_fn(path)
        features = self.processor(
            speech,
            sampling_rate=self.processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )

        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        return self._format_scores(logits)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload; the audio file path is read from its
                ``"path"`` key.

        Returns:
            The list produced by ``predict``. When no (non-empty) path is
            supplied, a single ``{"error": ...}`` dict is returned instead
            of raising.
        """
        path = data.get("path")
        if not path:
            return {"error": "Path to the audio file is required."}
        return self.predict(path)