# realtimespeech / models.py
# Author: DiegoLigtenberg
# Last commit: e711356 "Add requirements file" (file size 4.99 kB)
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, pipeline
from pydub import AudioSegment
import whisper
from settings import MODEL_PARSER
from pytube import YouTube
class UnknownModelError(ValueError, AssertionError):
    """Raised by BagOfModels.load_model for a name that is not configured.

    Also derives from AssertionError so callers written against the earlier
    ``assert``-based check keep catching it.
    """


class BagOfModels:
    '''Access point for the model configuration held in ``settings.MODEL_PARSER``.

    model -> mapping of model name to the keyword arguments used to build a
        :class:`Model` for it (``task``, ``url`` and any optional extras)
    model_names -> model names that can be chosen from in Streamlit
    model_settings -> settings of a model that can be customized by the user
    model_tasks -> tasks offered for the models (shape defined in settings)
    '''

    # Configuration namespace from settings; vars() of it supplies the
    # keyword arguments of __init__.
    args = MODEL_PARSER
    barfs = 5  # NOTE(review): not referenced in this file; confirm it is still needed.

    def __init__(self, model, model_names, model_settings, model_tasks, **kwargs):
        """Store the configuration entries; extra namespace attributes go to ``kwargs``."""
        self.model = model
        self.model_names = model_names
        self.model_settings = model_settings
        self.model_tasks = model_tasks
        self.kwargs = kwargs

    @classmethod
    def _from_settings(cls):
        """Build a bag from the class-level settings namespace ``cls.args``."""
        return cls(**vars(cls.args))

    @classmethod
    def get_model_settings(cls):
        """Return the user-customisable model settings from the configuration."""
        return cls._from_settings().model_settings

    @classmethod
    def get_model_names(cls):
        """Return the configured model names."""
        return cls._from_settings().model_names

    @classmethod
    def get_model(cls):
        """Return the mapping of model name to Model keyword arguments."""
        return cls._from_settings().model

    @classmethod
    def get_model_tasks(cls):
        """Return the configured model tasks."""
        return cls._from_settings().model_tasks

    @classmethod
    def load_model(cls, model_name, **kwargs):
        """Return a :class:`Model` built from the configuration of ``model_name``.

        Args:
            model_name: one of the configured model names.
            **kwargs: accepted for backward compatibility; ignored.

        Raises:
            UnknownModelError: if ``model_name`` is not a configured name.
        """
        bag_of_models = cls._from_settings()
        # Kept from the original: the model mapping is also exposed on the class.
        cls.model = bag_of_models.model
        if model_name not in bag_of_models.model_names:
            # Explicit raise instead of assert, so the check survives `python -O`.
            raise UnknownModelError(
                f"please pick one of the available models: {bag_of_models.model_names}"
            )
        return Model(model_name, **cls.model[model_name])
class Model:
    """A single configured model, identified by its Hugging Face URL.

    __init__ sets ``url``, ``model_name``, ``task``, ``kwargs`` (the remaining
    keyword arguments), ``name`` (the part of ``url`` after
    "https://huggingface.co/") and the optional ``_year`` / ``_description``.
    """

    def __init__(self, model_name, task, url, **kwargs):
        """Record the model identity and split off the optional metadata.

        Raises:
            IndexError: if ``url`` does not contain "https://huggingface.co/".
        """
        hub_prefix = "https://huggingface.co/"
        self.url = url
        self.model_name = model_name
        # Everything after the hub prefix, e.g. "<owner>/<repo>".
        self.name = self.url.split(hub_prefix)[1]
        self.task = task
        self.kwargs = kwargs
        self.init_optional_args(**self.kwargs)

    def init_optional_args(self, year=None, description=None):
        """Store the optional ``year`` and ``description`` metadata (default None)."""
        self._year, self._description = year, description

    def predict_stt(self, source, source_type, model_task):
        """Run speech-to-text on ``source`` and return the SoundToText object.

        The Whisper checkpoint size is the second "_"-separated field of
        ``model_name`` (e.g. "whisper_base" -> "base").
        """
        checkpoint_size = self.model_name.split("_")[1]  # tiny - base - medium
        whisper_model = whisper.load_model(checkpoint_size)
        speech = SoundToText(source, source_type, model_task, model=whisper_model, tokenizer=None)
        speech.whisper()
        return speech

    def predict_summary(self):
        """Load the Wav2Vec2 processor and CTC model for ``self.name``.

        NOTE(review): despite its name, this only loads the two objects and
        discards them. It returns None and produces no summary; it looks unfinished.
        """
        _processor = Wav2Vec2Processor.from_pretrained(self.name)
        _ctc_model = Wav2Vec2ForCTC.from_pretrained(self.name)  # Note: PyTorch Model
class Transcription:
    """Placeholder for a transcription result; currently stores nothing."""

    def __init__(self, model, source, source_type) -> None:
        """Accept a model, a source and its type, and discard them (stub)."""
class SoundToText():
    """Transcribe an audio source (YouTube URL or uploaded file) with Whisper.

    After :meth:`whisper` completes, the instance exposes ``audio_path``,
    ``raw_output`` (the dict returned by ``model.transcribe``), ``text``,
    ``language``, ``segments`` (with each segment's ``"tokens"`` removed) and
    ``transcribed = True``.
    """

    def __init__(self, source, source_type, model_task, model, tokenizer=None):
        """Store the inputs; no audio is fetched or decoded yet.

        Args:
            source: a YouTube URL when ``source_type == "YouTube"``, or an
                uploaded file-like object with a ``.name`` attribute when
                ``source_type == "File"``.
            source_type: ``"YouTube"`` or ``"File"``.
            model_task: stored only. NOTE(review): not used by :meth:`whisper`;
                confirm whether it should be forwarded to ``transcribe``.
            model: a loaded Whisper model, or None to fall back to ``"base"``.
            tokenizer: unused by the Whisper path.
        """
        self.source = source
        self.source_type = source_type
        self.model = model
        self.model_task = model_task
        self.tokenizer = tokenizer
        # Flipped to True by whisper() once a transcription has completed.
        self.transcribed = False

    def wav2vec(self, size):
        """Unimplemented stub."""
        pass

    def wav2vec2(self, size):
        """Unimplemented stub."""
        pass

    def _prepare_audio(self):
        """Return a local file path holding the audio of ``self.source``.

        Raises:
            ValueError: for an unknown ``source_type``, an uploaded file that
                is not .wav/.mp3, or a video without an itag-140 stream.
        """
        import os

        if self.source_type == "YouTube":
            # get_by_itag returns None when the video has no such stream.
            stream = YouTube(self.source).streams.get_by_itag(140)
            if stream is None:
                raise ValueError(f"no itag-140 audio stream available for {self.source}")
            return stream.download("output/", filename="audio")
        if self.source_type == "File":
            # Case-insensitive so ".WAV" / ".MP3" uploads are accepted too.
            file_name = self.source.name.lower()
            if file_name.endswith(".wav"):
                audio = AudioSegment.from_wav(self.source)
            elif file_name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(self.source)
            else:
                raise ValueError(
                    f"unsupported audio file {self.source.name!r}: expected .wav or .mp3"
                )
            # export() does not create the directory itself.
            os.makedirs("output", exist_ok=True)
            audio.export("output/audio.wav", format="wav")
            return "output/audio.wav"
        raise ValueError(
            f"unknown source_type {self.source_type!r}: expected 'YouTube' or 'File'"
        )

    def whisper(self):
        """Fetch or convert the audio, transcribe it and store the results.

        Raises:
            ValueError: see :meth:`_prepare_audio`.
        """
        self.audio_path = self._prepare_audio()
        # Use the model the caller loaded (e.g. the size picked in the UI).
        # Previously a fresh "base" model was always loaded here, which ignored
        # that choice and loaded the weights twice.
        model = self.model if self.model is not None else whisper.load_model("base")
        self.raw_output = model.transcribe(self.audio_path, verbose=True)
        self.text = self.raw_output["text"]
        self.language = self.raw_output["language"]
        self.segments = self.raw_output["segments"]
        # Remove token ids from the output (tolerate segments without them).
        for segment in self.segments:
            segment.pop("tokens", None)
        self.transcribed = True
class TextToSummary:
    """Summarise text with the "facebook/bart-large-cnn" summarization pipeline.

    The summary is computed eagerly in __init__; get_summary() returns the
    stored pipeline output.
    """

    def __init__(self, input_text, min_length, max_length):
        """Build the pipeline and run it on ``input_text`` immediately.

        ``min_length`` and ``max_length`` are forwarded to the pipeline call
        together with ``do_sample=False``.
        """
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        self.summary_input = input_text
        call_options = dict(min_length=min_length, max_length=max_length, do_sample=False)
        self.summary_output = self.summarizer(self.summary_input, **call_options)

    def get_summary(self):
        """Return the raw output of the summarization pipeline call."""
        return self.summary_output

    def wav2vec(self):
        """Unimplemented stub; returns None."""
        return None
def record(model_name, source=None, source_type=None, model_task=None):
    """Load ``model_name`` from the configuration and transcribe ``source``.

    The previous version called ``Model.predict()``, which does not exist, so
    it always failed with AttributeError. The inputs ``predict_stt`` needs are
    now optional parameters.

    Args:
        model_name: a configured model name (e.g. "whisper_base").
        source: YouTube URL or uploaded file object, as SoundToText expects.
        source_type: "YouTube" or "File".
        model_task: passed through to ``Model.predict_stt``.

    Returns:
        The object returned by ``Model.predict_stt``.

    Raises:
        ValueError: if ``source`` or ``source_type`` is not supplied.
    """
    if source is None or source_type is None:
        raise ValueError("record() needs both `source` and `source_type` to transcribe")
    model = BagOfModels.load_model(model_name)
    return model.predict_stt(source, source_type, model_task)
if __name__ == "__main__":
    # Manual smoke test: python models.py <youtube-url>
    # The previous bare predict_stt() call always raised TypeError because its
    # three required arguments were missing.
    import sys

    if len(sys.argv) < 2:
        sys.exit(
            f"usage: python {sys.argv[0]} <youtube-url>\n"
            f"available models: {BagOfModels.get_model_names()}"
        )
    whisper_base = BagOfModels.load_model("whisper_base")
    result = whisper_base.predict_stt(sys.argv[1], "YouTube", None)
    print(result.text)