# multirag / voice_chat.py
# Uploaded by gk8686 ("Upload 15 files", commit c9b6cb4, verified; no virus detected), 2.88 kB.
import whisper
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from chatbot import Chatbot
from utils.models_and_path import WHISPER_MODEL_NAME
class WhisperChatbot(Chatbot):
    """Chatbot front-end that takes spoken input instead of text.

    The audio is transcribed with Whisper. When the detected language is not
    English the transcript is translated to English through an LLM chain, the
    English text is answered by the parent ``Chatbot.response``, and the
    answer is translated back into the detected language.

    Per-request state (``audio``, ``lang``, ``text``, ``transcribed_text``)
    is kept on the instance and reset at the start of every ``response`` call.
    """

    def __init__(self, **kwargs):
        """Load the Whisper model and build the translation chain.

        Args:
            **kwargs: Forwarded unchanged to ``Chatbot.__init__``.
        """
        super().__init__(**kwargs)
        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME)
        # Create the per-request attributes up front so they exist (as None)
        # even before the first call to ``response``.
        self._clean_audio()
        self._load_translation_engine()

    def response(self, audio):
        """Answer a spoken question.

        Args:
            audio: Audio input passed straight to ``whisper.load_audio``
                (presumably a path to an audio file; confirm against callers).

        Returns:
            A 5-tuple ``(transcribed_text, text, lang, en_result,
            result_translated)``: the raw Whisper transcript, the English
            text sent to the parent chatbot, the detected language code, the
            parent chatbot's answer, and that answer translated back to
            ``lang`` (identical to ``en_result`` when ``lang`` is ``"en"``).
        """
        self._clean_audio()
        self._load_audio(audio)
        self._process_audio()

        en_result = super().response(self.text)
        if self.lang != "en":
            result_translated = self._translate_text(
                text=en_result,
                source="en",
                target=self.lang,
            )['text']
        else:
            result_translated = en_result

        return (
            self.transcribed_text,
            self.text,
            self.lang,
            en_result,
            result_translated,
        )

    def _load_translation_engine(self):
        """Build the prompt and LLM chain used for translation.

        Uses ``self.LLM``, which this class never assigns (presumably set by
        ``Chatbot.__init__``; verify in the parent class).
        """
        self.translation_prompt = PromptTemplate(
            input_variables=["source", "target", "text"],
            template="Translate from language {source} to {target}: {text}?",
        )
        self.translation_chain = LLMChain(
            llm=self.LLM,
            prompt=self.translation_prompt,
        )

    def _load_audio(self, audio):
        """Load ``audio`` and store it in ``self.audio``.

        The waveform is padded or trimmed with ``whisper.pad_or_trim`` so it
        fits Whisper's 30-second window.
        """
        waveform = whisper.load_audio(audio)
        self.audio = whisper.pad_or_trim(waveform)

    def _process_audio(self):
        """Detect the language, transcribe, and produce English text.

        Sets ``self.lang`` (most probable language), ``self.transcribed_text``
        (the Whisper transcript) and ``self.text`` (the transcript itself for
        English, otherwise its translation to English).
        """
        # Make the log-Mel spectrogram and move it to the model's device.
        mel = whisper.log_mel_spectrogram(self.audio).to(self.whisper_model.device)

        # Detect the language: keep the most probable one.
        _, probas = self.whisper_model.detect_language(mel)
        self.lang = max(probas, key=probas.get)

        # Decode the audio. The detected language is passed explicitly so
        # Whisper does not run a second language-identification pass inside
        # ``decode`` and the transcript matches the reported ``self.lang``.
        options = whisper.DecodingOptions(language=self.lang, fp16=False)
        self.transcribed_text = whisper.decode(self.whisper_model, mel, options).text

        # English transcripts are used as is; anything else is translated
        # from the detected language to English first.
        if self.lang == "en":
            self.text = self.transcribed_text
        else:
            self.text = self._translate_text(
                text=self.transcribed_text,
                source=self.lang,
                target="en",
            )['text']

    def _translate_text(self, text, source, target):
        """Run the translation chain.

        Args:
            text: Text to translate.
            source: Language to translate from.
            target: Language to translate to.

        Returns:
            The chain's output; callers in this class read the translation
            from its ``'text'`` key.
        """
        return self.translation_chain({
            "source": source,
            "target": target,
            "text": text,
        })

    def _clean_audio(self):
        """Reset all per-request state to ``None``."""
        self.audio = None
        self.lang = None
        self.text = None
        self.transcribed_text = None