# Chatty_Ashe / app.py — Hugging Face Space entry point
# Uploaded by gdnartea (commit e983f3f, verified)
# imports
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from transformers import AutoTokenizer, VitsModel, set_seed, AutoModelForCausalLM, AutoTokenizer, pipeline
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
import time
# Fix the global torch RNG so text generation is reproducible across runs.
torch.random.manual_seed(0)
# --- Chat LLM: Phi-3-mini, pinned to a specific revision for reproducibility.
proc_model_name = "microsoft/Phi-3-mini-4k-instruct"
proc_model = AutoModelForCausalLM.from_pretrained(
    proc_model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation='eager',
    revision='300945e90b6f55d3cb88261c8e5333fae696f672',
)
# NOTE(review): float16 weights on CPU — works for this model but is slow;
# confirm this is intentional rather than a leftover from a GPU config.
proc_model.to("cpu")
proc_tokenizer = AutoTokenizer.from_pretrained(proc_model_name)
# --- ASR front-end constraints.
SAMPLE_RATE = 16000 # Hz
MAX_AUDIO_MINUTES = 10 # wont try to transcribe if longer than this
# --- ASR: NVIDIA Canary-1b multitask speech recognition model.
model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()
# make sure beam size always 1 for consistency
# (reset to the default strategy first, then edit the cfg and re-apply —
# this order matters to NeMo's decoding-strategy API)
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)
# --- TTS: Facebook MMS English text-to-speech (VITS architecture).
vits_model_name = "facebook/mms-tts-eng"
vits_model = VitsModel.from_pretrained(vits_model_name)
vits_tokenizer = AutoTokenizer.from_pretrained(vits_model_name)
# Seed transformers' RNG so the synthesized voice is deterministic.
set_seed(555)
def text_to_speech(text_response):
    """Synthesize `text_response` to speech with the MMS VITS model.

    Returns the path of a wav file containing the synthesized audio.

    Fix: the original wrote to a fixed 'output.wav' in the working
    directory, so two concurrent Gradio requests would clobber each
    other's audio. Each call now writes to a unique temp-file path.
    """
    inputs = vits_tokenizer(text=text_response, return_tensors="pt")
    with torch.no_grad():
        outputs = vits_model(**inputs)
    waveform = outputs.waveform[0]
    # Unique per-request filename avoids races between simultaneous users.
    out_path = os.path.join(tempfile.gettempdir(), f"tts_{uuid.uuid4()}.wav")
    sf.write(out_path, waveform.numpy(), vits_model.config.sampling_rate)
    return out_path
def convert_audio(audio_filepath, tmpdir, utt_id):
    """Normalize input audio for the ASR model.

    Loads `audio_filepath`, downmixes to mono, resamples to SAMPLE_RATE
    if needed, and writes the result to `<tmpdir>/<utt_id>.wav`.

    Returns (out_filename, duration_seconds).
    Raises gr.Error if the clip exceeds MAX_AUDIO_MINUTES.

    Fix: MAX_AUDIO_MINUTES was declared ("wont try to transcribe if
    longer than this") but never enforced; the guard now lives here,
    where the duration is computed.
    """
    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=data, sr=sr)
    if duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can only transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "Please provide a shorter audio clip."
        )
    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
    out_filename = os.path.join(tmpdir, utt_id + '.wav')
    # save output audio
    sf.write(out_filename, data, SAMPLE_RATE)
    return out_filename, duration
def transcribe(audio_filepath):
    """Transcribe `audio_filepath` with the Canary ASR model.

    Builds a one-line NeMo manifest in a temporary directory and runs
    `model.transcribe` on it; returns the recognized text.
    """
    print(audio_filepath)
    # Brief pause so a just-recorded microphone clip finishes uploading.
    time.sleep(2)
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone. \nIf the microphone already has audio, please wait a few moments for it to upload properly")
    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
        # One JSON object per line is NeMo's manifest format; Canary's
        # multitask decoder reads the task fields (asr, en->en, pnc on).
        manifest_entry = dict(
            audio_filepath=converted_audio_filepath,
            source_lang="en",
            target_lang="en",
            taskname="asr",
            pnc="yes",
            answer="predict",
            duration=str(duration),
        )
        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w') as manifest_file:
            manifest_file.write(json.dumps(manifest_entry) + '\n')
        output_text = model.transcribe(manifest_filepath)[0]
    return output_text
# System prompt prepended to every conversation sent to the chat LLM.
start = dict(
    role="system",
    content="You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
)
def generate_response(user_input):
    """Generate a chat reply to `user_input` using the Phi-3 model.

    The system prompt (`start`) is prepended, the chat template is
    applied, and up to 100 new tokens are generated greedily under
    `no_grad`. Returns the decoded text of the first sequence.
    """
    chat = [start, {"role": "user", "content": user_input}]
    input_ids = proc_tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    )
    with torch.no_grad():
        generated = proc_model.generate(input_ids, max_new_tokens=100)
    decoded = proc_tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
def CanaryPhiVits(user_voice):
    """Full voice-chat pipeline: speech -> text -> LLM reply -> speech.

    Takes the path of the user's recorded audio and returns the path of
    a wav file with the assistant's spoken reply.
    """
    user_input = transcribe(user_voice)
    print("user_input:")
    print(user_input)
    response = generate_response(user_input)
    # The decoded LLM output can echo the prompt; strip that prefix.
    if response.startswith(user_input):
        response = response[len(user_input):]
    print("chatty_response:")
    print(response)
    return text_to_speech(response)
# Create a Gradio interface
# Input: microphone or uploaded file, passed to the pipeline as a wav path.
voice_in = gr.Audio(
    sources=["microphone", "upload"],
    label="Input Audio",
    type="filepath",
    format="wav",
)
# Output: the synthesized reply audio.
voice_out = gr.Audio(label="Output Audio")
iface = gr.Interface(
    fn=CanaryPhiVits,
    title="Chatty Ashe",
    #theme="gstaff/xkcd",
    inputs=voice_in,
    outputs=voice_out,
)
# Launch the interface
iface.queue()
iface.launch()