sanchit-gandhi HF staff commited on
Commit
071c26a
1 Parent(s): d7fec45

Update asr_diarizer.py

Browse files
Files changed (1) hide show
  1. asr_diarizer.py +14 -3
asr_diarizer.py CHANGED
@@ -9,19 +9,30 @@ from transformers import pipeline
9
  class ASRDiarizationPipeline:
10
  def __init__(
11
  self,
 
 
 
 
 
 
 
 
 
12
  asr_model: Optional[str] = "openai/whisper-small",
13
  diarizer_model: Optional[str] = "pyannote/speaker-diarization",
14
  chunk_length_s: int = 30,
 
15
  **kwargs,
16
  ):
17
- self.asr_pipeline = pipeline(
18
  "automatic-speech-recognition",
19
  model=asr_model,
20
- use_auth_token=True,
21
  chunk_length_s=chunk_length_s,
 
22
  **kwargs,
23
  )
24
- self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)
 
25
 
26
  def __call__(
27
  self,
 
9
  class ASRDiarizationPipeline:
10
  def __init__(
11
  self,
12
+ asr_pipeline,
13
+ diarization_pipeline,
14
+ ):
15
+ self.asr_pipeline = asr_pipeline
16
+ self.diarization_pipeline = diarization_pipeline
17
+
18
+ @classmethod
19
+ def from_pretrained(
20
+ cls,
21
  asr_model: Optional[str] = "openai/whisper-small",
22
  diarizer_model: Optional[str] = "pyannote/speaker-diarization",
23
  chunk_length_s: int = 30,
24
+ use_auth_token: Union[str, bool] = True,
25
  **kwargs,
26
  ):
27
+ asr_pipeline = pipeline(
28
  "automatic-speech-recognition",
29
  model=asr_model,
 
30
  chunk_length_s=chunk_length_s,
31
+ use_auth_token=use_auth_token,
32
  **kwargs,
33
  )
34
+ diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
35
+ cls(asr_pipeline, diarization_pipeline)
36
 
37
  def __call__(
38
  self,