import logging
import torch
import base64
import os

from indexify_extractor_sdk import Content, Extractor, Feature
from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from .diarization_utils import diarize
from huggingface_hub import HfApi
from starlette.exceptions import HTTPException

from pydantic import BaseModel
from pydantic_settings import BaseSettings
from typing import Optional, Literal, List, Union

logger = logging.getLogger(__name__)
token = os.getenv('HF_TOKEN')

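# Default checkpoints. Because this is a pydantic BaseSettings class, every
# field can be overridden with an environment variable of the same name
# (case-insensitive), e.g. `export ASR_MODEL=openai/whisper-medium`.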
class ModelSettings(BaseSettings):
    asr_model: str = "openai/whisper-large-v3"
    assistant_model: Optional[str] = "distil-whisper/distil-large-v3"
    diarization_model: Optional[str] = "pyannote/speaker-diarization-3.1"
    hf_token: Optional[str] = token

model_settings = ModelSettings()

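# Per-request inference parameters. `task` and `language` are forwarded to
# Whisper's generate(); `batch_size` and `chunk_length_s` drive chunked
# long-form transcription; the speaker counts are presumably passed through
# `diarize` to pyannote's pipeline (see diarization_utils).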
class ASRExtractorConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
    batch_size: int = 24
    assisted: bool = False
    chunk_length_s: int = 30
    sampling_rate: int = 16000
    language: Optional[str] = None
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None

class ASRExtractor(Extractor):
    name = "tensorlake/asrdiarization"
    description = "Powerful ASR + diarization + speculative decoding."
    system_dependencies = ["ffmpeg"]
    input_mime_types = ["audio", "audio/mpeg"]

    def __init__(self):
        super().__init__()

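        # Prefer the GPU when available; use fp16 there and fp32 on CPU, where
        # half precision is generally slow or unsupported.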
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info(f"Using device: {device.type}")
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

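        # Optional draft model for speculative decoding: the assistant proposes
        # candidate tokens that the main Whisper model verifies in a single
        # forward pass. Loading it with AutoModelForCausalLM keeps only the
        # decoder, since distil-whisper shares whisper-large-v3's encoder.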
        self.assistant_model = AutoModelForCausalLM.from_pretrained(
            model_settings.assistant_model,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True
        ) if model_settings.assistant_model else None

        if self.assistant_model:
            self.assistant_model.to(device)

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device
        )

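        # pyannote/speaker-diarization-3.1 is a gated checkpoint: the token must
        # belong to an account that has accepted the model's terms of use.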
        if model_settings.diarization_model:
            # Pipeline.from_pretrained does not fail loudly on a missing or
            # invalid token, so validate the credentials explicitly first.
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            self.diarization_pipeline.to(device)
        else:
            self.diarization_pipeline = None

    def extract(self, content: Content, params: ASRExtractorConfig) -> List[Union[Feature, Content]]:
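        """Transcribe base64-encoded audio and, when a diarization pipeline is
        configured, attribute the transcript segments to speakers.

        Returns a single Content whose text is the (possibly diarized)
        transcript and whose metadata feature carries the raw ASR output.
        """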
        file = base64.b64decode(content.data)
        logger.info(f"inference params: {params}")

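        # Passing `assistant_model` switches transformers' generate() into
        # assisted (speculative) generation; None falls back to regular decoding.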
        generate_kwargs = {
            "task": params.task, 
            "language": params.language,
            "assistant_model": self.assistant_model if params.assisted else None
        }

        try:
            asr_outputs = self.asr_pipeline(
                file,
                chunk_length_s=params.chunk_length_s,
                batch_size=params.batch_size,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")

        if self.diarization_pipeline:
            try:
                transcript = diarize(self.diarization_pipeline, file, params, asr_outputs)
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
        else:
            # No diarization model configured; the full transcription is still
            # available in the metadata feature below.
            transcript = []

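        # Attach the raw ASR output (timestamped chunks plus full text) as a
        # metadata feature alongside the transcript text.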
        feature = Feature.metadata(value={"chunks": asr_outputs["chunks"], "text": asr_outputs["text"]})
        return [Content.from_text(str(transcript), features=[feature])]
    
    def sample_input(self) -> Content:
        filepath = "sample.mp3"
        with open(filepath, 'rb') as f:
            audio_encoded = base64.b64encode(f.read()).decode("utf-8")
        return Content(content_type="audio/mpeg", data=audio_encoded)

if __name__ == "__main__":
    filepath = "sample.mp3"
    with open(filepath, 'rb') as f:
        audio_encoded = base64.b64encode(f.read()).decode("utf-8")
    data = Content(content_type="audio/mpeg", data=audio_encoded)
    params = ASRExtractorConfig(batch_size=24)
    extractor = ASRExtractor()
    results = extractor.extract(data, params=params)
    print(results)