spellingdragon commited on
Commit
acdd9da
1 Parent(s): 171330f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -21
handler.py CHANGED
@@ -5,31 +5,31 @@ from transformers import WhisperProcessor, AutoModelForSpeechSeq2Seq, AutoProces
5
 
6
  class EndpointHandler():
7
  def __init__(self, path=""):
8
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
- model_id = path
11
 
12
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
13
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
14
- )
15
- model.to(device)
16
 
17
- processor = AutoProcessor.from_pretrained(model_id)
18
  #processor = WhisperProcessor.from_pretrained(model_id)
19
 
20
- self.pipeline = pipeline(
21
- "automatic-speech-recognition",
22
- model=model,
23
- tokenizer=processor.tokenizer,
24
- feature_extractor=processor.feature_extractor,
25
- max_new_tokens=128,
26
- chunk_length_s=30,
27
- batch_size=16,
28
- return_timestamps=True,
29
- torch_dtype=torch_dtype,
30
- device=device,
31
- )
32
- self.model = model
33
 
34
 
35
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
 
5
 
6
  class EndpointHandler():
7
  def __init__(self, path=""):
8
+ #device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
+ #torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
+ #model_id = path
11
 
12
+ #model = AutoModelForSpeechSeq2Seq.from_pretrained(
13
+ # model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
14
+ #)
15
+ #model.to(device)
16
 
17
+ #processor = AutoProcessor.from_pretrained(model_id)
18
  #processor = WhisperProcessor.from_pretrained(model_id)
19
 
20
+ #self.pipeline = pipeline(
21
+ # "automatic-speech-recognition",
22
+ # model=model,
23
+ # tokenizer=processor.tokenizer,
24
+ # feature_extractor=processor.feature_extractor,
25
+ # max_new_tokens=128,
26
+ # chunk_length_s=30,
27
+ # batch_size=16,
28
+ # return_timestamps=True,
29
+ # torch_dtype=torch_dtype,
30
+ # device=device,
31
+ #)
32
+ #self.model = model
33
 
34
 
35
  def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]: