import whisper
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

from chatbot import Chatbot
from utils.models_and_path import WHISPER_MODEL_NAME


class WhisperChatbot(Chatbot):
    """Chatbot that accepts spoken audio: transcribes it with Whisper,
    answers in English via the base Chatbot, and translates the answer
    back into the detected language."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME)
        self._load_translation_engine()


    def response(self, audio):
        """Transcribe `audio`, answer in English, and translate the answer back
        to the detected language. Returns (transcribed_text, text, lang,
        en_result, result_translated)."""
        self._clean_audio()
        self._load_audio(audio)
        self._process_audio()

        en_result = super().response(self.text)
        if self.lang != "en":
            result_translated = self._translate_text(text=en_result, source="en", target=self.lang)['text']
        else:
            result_translated = en_result

        return self.transcribed_text, self.text, self.lang, en_result, result_translated


    def _load_translation_engine(self):
        self.translation_prompt = PromptTemplate(
            input_variables=["source", "target", "text"],
            template="Translate from language {source} to {target}: {text}",
        )
        self.translation_chain = LLMChain(llm=self.LLM, prompt=self.translation_prompt)


    def _load_audio(self, audio):
        # assert isinstance(audio, str), "Audio must be a file path"
        # assert self.whisper_model, "Whisper model not loaded"

        # load audio and pad/trim it to fit 30 seconds
        self.audio = whisper.pad_or_trim(
            whisper.load_audio(audio)
        )


    def _process_audio(self):
        # assert self.audio, "Audio not loaded"
        # assert self.whisper_model, "Whisper model not loaded"

        # Make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(self.audio).to(self.whisper_model.device)

        # Detect language
        _, probas = self.whisper_model.detect_language(mel)
        self.lang = max(probas, key=probas.get)

        # Decode the audio
        options = whisper.DecodingOptions(fp16=False)
        self.transcribed_text = whisper.decode(self.whisper_model, mel, options).text

        # Check the language of the audio:
        # if it's English, use the transcribed text as is;
        # otherwise, translate it to English
        if self.lang == "en":
            self.text = self.transcribed_text
        else:
            # translate from detected lang to en
            self.text = self._translate_text(
                text=self.transcribed_text,
                source=self.lang,
                target="en"
            )['text']


    def _translate_text(self, text, source, target):
        return self.translation_chain({
            "source": source,
            "target": target,
            "text": text
        })


    def _clean_audio(self):
        self.audio = None
        self.lang = None
        self.text = None
        self.transcribed_text = None
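

# Minimal usage sketch (an assumption, not part of the original module):
# it presumes the base Chatbot can be constructed without extra keyword
# arguments and that "sample.wav" is a path to an existing audio file.
if __name__ == "__main__":
    bot = WhisperChatbot()
    transcribed, en_text, lang, en_answer, answer = bot.response("sample.wav")
    print(f"[{lang}] heard: {transcribed}")
    print(f"answer: {answer}")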