# TTS-STT-Blocks / app.py
# Hugging Face Space by awacke1 - commit bfbd9e3 ("Update app.py"), 4.97 kB.
import streamlit as st
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import datetime
from transformers import pipeline
import gradio as gr
import tempfile
from typing import Optional
import numpy as np
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
@st.experimental_singleton
def get_db_firestore():
    """Return a Firestore client for the ``clinical-nlp-b9117`` project.

    Service-account credentials are read from ``test.json`` in the working
    directory. ``firebase_admin.initialize_app`` raises ``ValueError`` when
    the default app already exists (e.g. if this module is imported a second
    time in the same process), so an existing default app is reused instead
    of being initialized twice.
    """
    try:
        app = firebase_admin.get_app()
    except ValueError:
        # No default app yet: create it from the service-account file.
        cred = credentials.Certificate("test.json")
        app = firebase_admin.initialize_app(cred, {"projectId": "clinical-nlp-b9117"})
    return firestore.client(app)
# Module-wide handles: the (singleton-cached) Firestore client used by the
# save/retrieve handlers, and the wav2vec2 speech-recognition pipeline.
db = get_db_firestore()
asr = pipeline(task="automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
# Coqui TTS model ids (relative to "tts_models/") that are downloaded at
# startup and offered, in this order, as choices in the TTS radio button.
MODEL_NAMES = [
    "en/ljspeech/tacotron2-DDC",
    "fr/mai/tacotron2-DDC",
    "zh-CN/baker/tacotron2-DDC-GST",
    "nl/mai/tacotron2-DDC",
    "de/thorsten/tacotron2-DCA",
]
# Currently disabled ids (move back into the list above to enable):
#   en/ek1/tacotron2
#   en/ljspeech/tacotron2-DDC_ph
#   en/ljspeech/glow-tts
#   en/ljspeech/tacotron2-DCA
#   en/ljspeech/speedy-speech-wn
#   en/ljspeech/vits
#   en/vctk/sc-glow-tts
#   en/vctk/vits
#   en/sam/tacotron-DDC
#   es/mai/tacotron2-DDC
#   ja/kokoro/tacotron2-DDC
# Download every model in MODEL_NAMES (plus its default vocoder, when the
# model entry names one) and build one Synthesizer per model, keyed by the
# model id used in MODEL_NAMES.
MODELS = {}
manager = ModelManager()
for MODEL_NAME in MODEL_NAMES:
    print(f"downloading {MODEL_NAME}")
    model_path, config_path, model_item = manager.download_model(f"tts_models/{MODEL_NAME}")
    # .get(): a model entry may not carry a "default_vocoder" key at all;
    # treat that the same as an explicit None (no separate vocoder).
    vocoder_name: Optional[str] = model_item.get("default_vocoder")
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)
    # Keyword arguments: later Coqui TTS releases inserted a
    # `tts_languages_file` parameter after `tts_speakers_file`, so a
    # positional call would pass the vocoder checkpoint into the wrong slot.
    synthesizer = Synthesizer(
        tts_checkpoint=model_path,
        tts_config_path=config_path,
        tts_speakers_file=None,
        vocoder_checkpoint=vocoder_path,
        vocoder_config=vocoder_config_path,
    )
    MODELS[MODEL_NAME] = synthesizer
def transcribe(audio):
    """Transcribe *audio* with the wav2vec2 ``asr`` pipeline.

    *audio* is expected to be an audio file path (or other input the pipeline
    accepts). A falsy value (``None`` / ``""``), which is what the UI
    delivers when nothing was recorded, yields ``""`` instead of being passed
    into the pipeline.
    """
    if not audio:
        return ""
    return asr(audio)["text"]
# Sentiment classifier used by text_to_sentiment. No model id is given, so
# transformers falls back to its default checkpoint for this task.
classifier = pipeline(task="text-classification")
def speech_to_text(speech):
    """Transcribe recorded audio with the wav2vec2 ``asr`` pipeline.

    *speech* is the value of the microphone input (``type="filepath"``): a
    path to the recorded audio file, or ``None`` when the button is pressed
    without recording anything. Returns the transcript string, or ``""``
    when there is no audio.
    """
    if not speech:
        # Nothing recorded: return an empty transcript rather than passing
        # None into the pipeline.
        return ""
    return asr(speech)["text"]
def text_to_sentiment(text):
    """Classify *text* and return the label of the first prediction.

    Uses the module-level ``classifier`` pipeline; the returned value is the
    ``"label"`` field of element 0 of the pipeline's result list.
    """
    predictions = classifier(text)
    first = predictions[0]
    return first["label"]
def upsert(text):
    """Save *text* as a new document in ``Text2SpeechSentimentSave``.

    The document id (and its ``born`` field) is ``str()`` of the current
    local date-time, so each call writes a new document. The stored document
    is then read back via ``select`` and that result is returned for display.
    Stored records can be inspected in the Firebase console for project
    ``clinical-nlp-b9117``.
    """
    stamp = str(datetime.datetime.today())
    record = {
        'firefield': 'Recognize Speech',
        'first': 'https://huggingface.co/spaces/awacke1/Text2SpeechSentimentSave',
        'last': text,
        'born': stamp,
    }
    target = db.collection('Text2SpeechSentimentSave').document(stamp)
    target.set(record)
    # Read the document back so the UI shows what was actually stored.
    return select('Text2SpeechSentimentSave', stamp)
def select(collection, document):
    """Fetch one Firestore document and return its data labelled for display.

    Returns the 2-tuple ``("The contents are: ", data)`` where *data* is the
    document snapshot's ``to_dict()`` result (Firestore gives ``None`` for a
    document that does not exist).
    """
    snapshot = db.collection(collection).document(document).get()
    return ("The contents are: ", snapshot.to_dict())
def selectall(text):
    """Return every document in ``Text2SpeechSentimentSave`` as display text.

    *text* is unused; it is only present because the "Retrieve All" button
    is wired with the transcript textbox as its input.

    Each document is rendered as ``"<id> => <data dict>"``, one per line, so
    records stay readable in the output textbox. Returns ``""`` for an empty
    collection.
    """
    docs = db.collection('Text2SpeechSentimentSave').stream()
    return "\n".join(f'{doc.id} => {doc.to_dict()}' for doc in docs)
def tts(text: str, model_name: str):
    """Synthesize *text* with the preloaded Coqui model *model_name*.

    Returns the path of a new ``.wav`` file in the system temp directory. The
    file is created with ``delete=False`` so it outlives this call and the UI
    can play it; nothing here removes it afterwards.

    Raises:
        NameError: if *model_name* has no entry in ``MODELS``.
    """
    print(text, model_name)
    engine = MODELS.get(model_name)
    if engine is None:
        raise NameError("model not found")
    waveform = engine.tts(text)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        engine.save_wav(waveform, wav_file)
    return wav_file.name
# Gradio UI: record speech, transcribe it, classify its sentiment, save the
# transcript to Firestore, list saved transcripts, and read text back aloud
# with a chosen TTS model.
with gr.Blocks() as demo:
    # gr.Audio / gr.Radio replace the deprecated gr.inputs.* variants.
    audio_file = gr.Audio(source="microphone", type="filepath")
    text = gr.Textbox()
    label = gr.Label()
    saved = gr.Textbox()
    savedAll = gr.Textbox()
    TTSchoice = gr.Radio(label="Pick a TTS Model", choices=MODEL_NAMES)
    audio_file_out = gr.Audio(label="Generated Speech")
    b1 = gr.Button("Recognize Speech")
    b2 = gr.Button("Classify Sentiment")
    b3 = gr.Button("Save Speech to Text")
    b4 = gr.Button("Retrieve All")
    b5 = gr.Button("Read It Back Aloud")
    b1.click(speech_to_text, inputs=audio_file, outputs=text)
    b2.click(text_to_sentiment, inputs=text, outputs=label)
    b3.click(upsert, inputs=text, outputs=saved)
    b4.click(selectall, inputs=text, outputs=savedAll)
    # A *list* so tts receives (text, model_name) positionally and in order.
    # A set is unordered, and Gradio hands a set of inputs to the function as
    # one {component: value} dict, which tts(text, model_name) cannot accept.
    b5.click(tts, inputs=[text, TTSchoice], outputs=audio_file_out)
demo.launch(share=True)
#iface = gr.Interface(
# fn=tts,
# inputs=[
# gr.inputs.Textbox( label="Input", default="Hello, how are you?", ),
# gr.inputs.Radio( label="Pick a TTS Model", choices=MODEL_NAMES, ),
# ],
# outputs=gr.outputs.Audio(label="Output"),
# title="🐸💬 - Coqui TTS",
# theme="huggingface",
# description="🐸💬 - a deep learning toolkit for Text-to-Speech, battle-tested in research and production",
# article="more info at https://github.com/coqui-ai/TTS",
#)
#iface.launch()