spellingdragon committed
Commit dd17b27
Parent: e9204c7

Update handler.py

Files changed (1): handler.py +9 -42
handler.py CHANGED
@@ -1,54 +1,21 @@
-from typing import Dict, List, Any
-import torch
-from transformers.pipelines.audio_utils import ffmpeg_read
-from transformers import WhisperProcessor, AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline
+from typing import Dict, Any
+from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline
 
 class EndpointHandler():
     def __init__(self, path=""):
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
         model_id = "openai/whisper-large-v3"
-
-        model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-        )
-        model.to(device)
-
-        processor = AutoProcessor.from_pretrained(model_id)
-        #processor = WhisperProcessor.from_pretrained(model_id)
-
-        self.pipeline = pipeline(
-            "automatic-speech-recognition",
-            model=model_id,
-            tokenizer=processor.tokenizer,
-            feature_extractor=processor.feature_extractor,
-            max_new_tokens=128,
-            chunk_length_s=30,
-            batch_size=16,
-            return_timestamps=True,
-            torch_dtype=torch_dtype,
-            device=device,
-        )
-        self.model = model
-
+        task = "automatic-speech-recognition"
+        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.pipeline = pipeline(task, model=self.model, tokenizer=self.tokenizer)
 
     def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
-        """
-        Args:
-            data (:obj:):
-                includes the input data and the parameters for the inference.
-        Return:
-            A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
-            - "label": A string representing what the label/class is. There can be multiple labels.
-            - "score": A score between 0 and 1 describing how confident the model is for this label/class.
-        """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
 
-        # pass inputs with all kwargs in data
         if parameters is not None:
-            result = self.pipeline(inputs, return_timestamps=True, **parameters)
+            result = self.pipeline(inputs, return_timestamps=True, **parameters)
         else:
-            result = self.pipeline(inputs, return_timestamps=True, generate_kwargs={"task": "translate"})
-        # postprocess the prediction
+            result = self.pipeline(inputs, return_timestamps=True)
+
         return {"chunks": result["chunks"]}
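
For reference, a minimal sketch of how the slimmed-down handler could be exercised locally after this change. The test snippet and the "sample.flac" path are assumptions for illustration only and are not part of this commit.

# Minimal local sketch (assumption: run next to handler.py with any short
# audio file on disk; "sample.flac" is a placeholder path).
from handler import EndpointHandler

handler = EndpointHandler()

with open("sample.flac", "rb") as f:
    audio_bytes = f.read()

# Inference Endpoints passes a dict with raw audio bytes under "inputs" and an
# optional "parameters" dict that __call__ forwards to the pipeline call.
result = handler({"inputs": audio_bytes})

# With return_timestamps=True the pipeline emits timestamped segments, which
# the handler returns under "chunks".
for chunk in result["chunks"]:
    print(chunk["timestamp"], chunk["text"])

Omitting "parameters" takes the else branch, so the pipeline runs with only return_timestamps=True; supplying a "parameters" dict forwards its entries as keyword arguments to the pipeline.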