afurkank commited on
Commit
304c741
1 Parent(s): feea952

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -40
handler.py CHANGED
@@ -5,7 +5,6 @@ import base64
5
 
6
  from pyannote.audio import Pipeline
7
  from transformers import pipeline, AutoModelForCausalLM
8
- from diarization_utils import diarize
9
  from huggingface_hub import HfApi
10
  from pydantic import ValidationError
11
  from starlette.exceptions import HTTPException
@@ -22,16 +21,6 @@ class EndpointHandler():
22
  logger.info(f"Using device: {device.type}")
23
  torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
24
 
25
- self.assistant_model = AutoModelForCausalLM.from_pretrained(
26
- model_settings.assistant_model,
27
- torch_dtype=torch_dtype,
28
- low_cpu_mem_usage=True,
29
- use_safetensors=True
30
- ) if model_settings.assistant_model else None
31
-
32
- if self.assistant_model:
33
- self.assistant_model.to(device)
34
-
35
  self.asr_pipeline = pipeline(
36
  "automatic-speech-recognition",
37
  model=model_settings.asr_model,
@@ -39,18 +28,6 @@ class EndpointHandler():
39
  device=device
40
  )
41
 
42
- if model_settings.diarization_model:
43
- # diarization pipeline doesn't raise if there is no token
44
- HfApi().whoami(model_settings.hf_token)
45
- self.diarization_pipeline = Pipeline.from_pretrained(
46
- checkpoint_path=model_settings.diarization_model,
47
- use_auth_token=model_settings.hf_token,
48
- )
49
- self.diarization_pipeline.to(device)
50
- else:
51
- self.diarization_pipeline = None
52
-
53
-
54
  def __call__(self, inputs):
55
  file = inputs.pop("inputs")
56
  file = base64.b64decode(file)
@@ -65,8 +42,7 @@ class EndpointHandler():
65
 
66
  generate_kwargs = {
67
  "task": parameters.task,
68
- "language": parameters.language,
69
- "assistant_model": self.assistant_model if parameters.assisted else None
70
  }
71
 
72
  try:
@@ -81,23 +57,10 @@ class EndpointHandler():
81
  logger.error(f"ASR inference error: {str(e)}")
82
  raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
83
  except Exception as e:
84
- logger.error(f"Unknown error diring ASR inference: {str(e)}")
85
- raise HTTPException(status_code=500, detail=f"Unknown error diring ASR inference: {str(e)}")
86
-
87
- if self.diarization_pipeline:
88
- try:
89
- transcript = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
90
- except RuntimeError as e:
91
- logger.error(f"Diarization inference error: {str(e)}")
92
- raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
93
- except Exception as e:
94
- logger.error(f"Unknown error during diarization: {str(e)}")
95
- raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
96
- else:
97
- transcript = []
98
 
99
  return {
100
- "speakers": transcript,
101
  "chunks": asr_outputs["chunks"],
102
  "text": asr_outputs["text"],
103
  }
 
5
 
6
  from pyannote.audio import Pipeline
7
  from transformers import pipeline, AutoModelForCausalLM
 
8
  from huggingface_hub import HfApi
9
  from pydantic import ValidationError
10
  from starlette.exceptions import HTTPException
 
21
  logger.info(f"Using device: {device.type}")
22
  torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
23
 
 
 
 
 
 
 
 
 
 
 
24
  self.asr_pipeline = pipeline(
25
  "automatic-speech-recognition",
26
  model=model_settings.asr_model,
 
28
  device=device
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def __call__(self, inputs):
32
  file = inputs.pop("inputs")
33
  file = base64.b64decode(file)
 
42
 
43
  generate_kwargs = {
44
  "task": parameters.task,
45
+ "language": parameters.language
 
46
  }
47
 
48
  try:
 
57
  logger.error(f"ASR inference error: {str(e)}")
58
  raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
59
  except Exception as e:
60
+ logger.error(f"Unknown error during ASR inference: {str(e)}")
61
+ raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  return {
 
64
  "chunks": asr_outputs["chunks"],
65
  "text": asr_outputs["text"],
66
  }