osanseviero HF staff committed on
Commit
5913155
1 Parent(s): 46d31bb

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -3
pipeline.py CHANGED
@@ -9,9 +9,9 @@ class PreTrainedPipeline():
9
  """
10
  Initialize model
11
  """
12
- processor = Wav2Vec2Processor.from_pretrained(path)
13
- model = Wav2Vec2ForCTC.from_pretrained(path)
14
- vocab_list = list(processor.tokenizer.get_vocab().keys())
15
 
16
  # convert ctc blank character representation
17
  vocab_list[0] = ""
@@ -39,6 +39,8 @@ class PreTrainedPipeline():
39
  A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
40
  the detected text from the input audio.
41
  """
 
 
42
  return {
43
  "text": self.decoder.decode(logits)
44
  }
 
9
  """
10
  Initialize model
11
  """
12
+ self.processor = Wav2Vec2Processor.from_pretrained(path)
13
+ self.model = Wav2Vec2ForCTC.from_pretrained(path)
14
+ vocab_list = list(self.processor.tokenizer.get_vocab().keys())
15
 
16
  # convert ctc blank character representation
17
  vocab_list[0] = ""
 
39
  A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
40
  the detected text from the input audio.
41
  """
42
+ input_values = self.processor(arr, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
43
+ logits = self.model(input_values).logits.cpu().detach().numpy()[0]
44
  return {
45
  "text": self.decoder.decode(logits)
46
  }