|
import whisper |
|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import LLMChain |
|
|
|
from chatbot import Chatbot |
|
from utils.models_and_path import WHISPER_MODEL_NAME |
|
|
|
|
|
class WhisperChatbot(Chatbot):
    """Chatbot variant that takes spoken input instead of text.

    Each call to ``response`` transcribes the audio with Whisper, translates
    the transcription to English through an LLM chain when the detected
    language is not English, forwards the English text to the parent
    ``Chatbot.response``, and translates the answer back to the speaker's
    language.

    Per-request state written by ``response`` (reset on every call):
        audio: waveform returned by ``whisper.pad_or_trim``.
        lang: language code with the highest detection probability.
        transcribed_text: Whisper transcription in the spoken language.
        text: the transcription in English (same as ``transcribed_text``
            when ``lang == "en"``).
    """

    def __init__(self, **kwargs):
        """Initialise the parent chatbot, the Whisper model and the translator.

        Args:
            **kwargs: Forwarded unchanged to ``Chatbot.__init__``.
        """
        super().__init__(**kwargs)
        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME)
        self._load_translation_engine()

    def response(self, audio):
        """Answer a spoken question.

        Args:
            audio: Audio input handed to ``whisper.load_audio``
                (presumably a file path -- confirm against callers).

        Returns:
            A 5-tuple ``(transcribed_text, text, lang, en_result,
            result_translated)``: the raw transcription, the text sent to the
            parent chatbot, the detected language code, the parent chatbot's
            answer, and that answer translated to ``lang`` (identical to
            ``en_result`` when ``lang`` is ``"en"``).
        """
        # Drop state from any previous request before processing this one.
        self._clean_audio()
        self._load_audio(audio)
        self._process_audio()

        en_result = super().response(self.text)
        if self.lang != "en":
            result_translated = self._translate_text(
                text=en_result, source="en", target=self.lang
            )['text']
        else:
            result_translated = en_result

        return self.transcribed_text, self.text, self.lang, en_result, result_translated

    def _load_translation_engine(self):
        """Create the prompt template and LLM chain used for translation."""
        self.translation_prompt = PromptTemplate(
            input_variables=["source", "target", "text"],
            template="Translate from language {source} to {target}: {text}?",
        )
        # NOTE(review): self.LLM is not set in this class; presumably it is
        # provided by Chatbot.__init__ -- confirm.
        self.translation_chain = LLMChain(llm=self.LLM, prompt=self.translation_prompt)

    def _load_audio(self, audio):
        """Load ``audio`` and fit it to Whisper's fixed input length.

        ``whisper.pad_or_trim`` pads or truncates the waveform to the window
        the encoder expects, so audio beyond that window is discarded.
        """
        self.audio = whisper.pad_or_trim(
            whisper.load_audio(audio)
        )

    def _process_audio(self):
        """Detect the spoken language, transcribe, and translate to English.

        Reads ``self.audio`` and sets ``self.lang``, ``self.transcribed_text``
        and ``self.text``.
        """
        # Use the model's own mel-bin count: checkpoints do not all share the
        # library default, and a mismatch makes the encoder reject the input.
        mel = whisper.log_mel_spectrogram(
            self.audio, n_mels=self.whisper_model.dims.n_mels
        ).to(self.whisper_model.device)

        # Keep the most probable language code.
        _, probas = self.whisper_model.detect_language(mel)
        self.lang = max(probas, key=probas.get)

        # Pass the detected language so decode() does not run language
        # detection a second time and the transcription is tied to self.lang.
        options = whisper.DecodingOptions(language=self.lang, fp16=False)
        self.transcribed_text = whisper.decode(self.whisper_model, mel, options).text

        if self.lang == "en":
            self.text = self.transcribed_text
        else:
            # The parent chatbot is queried in English, so translate first.
            self.text = self._translate_text(
                text=self.transcribed_text,
                source=self.lang,
                target="en"
            )['text']

    def _translate_text(self, text, source, target):
        """Run the translation chain.

        Args:
            text: Text to translate.
            source: Language code of ``text``.
            target: Language code to translate into.

        Returns:
            The chain's output mapping; callers in this class read the
            translated string from its ``'text'`` key.
        """
        return self.translation_chain({
            "source": source,
            "target": target,
            "text": text
        })

    def _clean_audio(self):
        """Reset all per-request state to ``None``."""
        self.audio = None
        self.lang = None
        self.text = None
        self.transcribed_text = None