# imports
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import time
import uuid

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, VitsModel, set_seed
from nemo.collections.asr.models import ASRModel

torch.random.manual_seed(0)

# LLM used to generate the chat response (pinned to a specific revision for reproducibility)
proc_model_name = "microsoft/Phi-3-mini-4k-instruct"
proc_model = AutoModelForCausalLM.from_pretrained(
    proc_model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation='eager',
    revision='300945e90b6f55d3cb88261c8e5333fae696f672',
)
proc_model.to("cpu")
proc_tokenizer = AutoTokenizer.from_pretrained(proc_model_name)

SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_MINUTES = 10  # won't try to transcribe if longer than this

# ASR model (NVIDIA Canary)
model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()

# make sure beam size is always 1 for consistency
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)

# TTS model (MMS English VITS)
vits_model_name = "facebook/mms-tts-eng"
vits_model = VitsModel.from_pretrained(vits_model_name)
vits_tokenizer = AutoTokenizer.from_pretrained(vits_model_name)
set_seed(555)


def text_to_speech(text_response):
    """Synthesize the response text to a wav file and return its path."""
    inputs = vits_tokenizer(text=text_response, return_tensors="pt")
    with torch.no_grad():
        outputs = vits_model(**inputs)
    waveform = outputs.waveform[0]
    sf.write('output.wav', waveform.numpy(), vits_model.config.sampling_rate)
    return 'output.wav'


def convert_audio(audio_filepath, tmpdir, utt_id):
    """Resample the input audio to 16 kHz mono and save it as a wav file."""
    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=data, sr=sr)

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    out_filename = os.path.join(tmpdir, utt_id + '.wav')

    # save output audio
    sf.write(out_filename, data, SAMPLE_RATE)

    return out_filename, duration


def transcribe(audio_filepath):
    print(audio_filepath)
    time.sleep(2)  # give a microphone recording a moment to finish uploading
    if audio_filepath is None:
        raise gr.Error(
            "Please provide some input audio: either upload an audio file or use the microphone.\n"
            "If the microphone already has audio, please wait a few moments for it to upload properly."
        )

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))

        # enforce the MAX_AUDIO_MINUTES limit declared above
        if duration / 60.0 > MAX_AUDIO_MINUTES:
            raise gr.Error(
                f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
                "Please trim the audio and try again."
            )

        # make manifest file and save
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "yes",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w') as fout:
            line = json.dumps(manifest_data)
            fout.write(line + '\n')

        output_text = model.transcribe(manifest_filepath)[0]

    return output_text


start = {
    "role": "system",
    "content": "You are a helpful digital assistant. "
               "Please provide safe, ethical and accurate information to the user.",
}


def generate_response(user_input):
    messages = [start, {"role": "user", "content": user_input}]
    inputs = proc_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = proc_model.generate(
            inputs,
            max_new_tokens=100,
        )
    response = proc_tokenizer.batch_decode(
        outputs,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response


def CanaryPhiVits(user_voice):
    user_input = transcribe(user_voice)
    print("user_input:")
    print(user_input)
    response = generate_response(user_input)
    # the decoded output echoes the prompt; strip the user input when it leads the response
    if response.startswith(user_input):
        response = response.replace(user_input, '', 1)
    print("chatty_response:")
    print(response)
    chatty_response = text_to_speech(response)
    return chatty_response


# Create a Gradio interface
iface = gr.Interface(
    fn=CanaryPhiVits,
    title="Chatty Ashe",
    #theme="gstaff/xkcd",
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        label="Input Audio",
        type="filepath",
        format="wav",
    ),
    outputs=gr.Audio(
        label="Output Audio"
    ),
)

# Launch the interface
iface.queue()
iface.launch()