|
import gradio as gr |
|
import json |
|
import librosa |
|
import os |
|
import soundfile as sf |
|
import tempfile |
|
import uuid |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
from nemo.collections.asr.models import ASRModel |
|
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED |
|
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED |
|
from transformers import VitsTokenizer, VitsModel, set_seed |
|
import scipy.io.wavfile as wav |
|
|
|
|
|
# Target sample rate in Hz: input audio is resampled to this before ASR.
SAMPLE_RATE = 16_000
|
|
|
|
|
# --- ASR: NVIDIA Canary-1B (NeMo multitask AED model) ---
asr_model = ASRModel.from_pretrained("nvidia/canary-1b")
# Switch the model to evaluation mode for inference.
asr_model.eval()

# Reset decoding to the model's default config, then force beam size 1 so
# decoding is greedy-equivalent and results are consistent between runs.
asr_model.change_decoding_strategy(None)
decoding_cfg = asr_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
asr_model.change_decoding_strategy(decoding_cfg)

# Preprocessor settings for buffered (long-audio) inference: no dither and no
# padding. transcribe() passes this preprocessor config node to
# get_buffered_pred_feat_multitaskAED.
asr_model.cfg.preprocessor.dither = 0.0
asr_model.cfg.preprocessor.pad_to = 0

# NOTE(review): window_stride is presumably the feature hop in seconds.
feature_stride = asr_model.cfg.preprocessor['window_stride']
# Time covered by one encoder output step. NOTE(review): 8 is presumably the
# encoder's subsampling factor for Canary-1B; confirm against asr_model.cfg.encoder.
model_stride_in_secs = feature_stride * 8

# Chunked decoder used by transcribe() for long clips. NOTE(review): frame_len and
# total_buffer are presumably in seconds, matching transcribe()'s `duration < 40`
# cut-over between single-pass and buffered inference.
frame_asr = FrameBatchMultiTaskAED(
    asr_model=asr_model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)
|
|
|
|
|
# --- LLM: Phi-3-mini (128k context, instruct) behind a text-generation pipeline ---
_LLM_MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"

# Seed torch's global RNG before the model is loaded.
torch.manual_seed(0)

llm_model = AutoModelForCausalLM.from_pretrained(
    _LLM_MODEL_ID,
    device_map="auto",       # let transformers/accelerate place the weights
    torch_dtype="auto",      # keep the dtype recorded in the checkpoint config
    trust_remote_code=True,  # allow checkpoint-provided modeling code
)
tokenizer = AutoTokenizer.from_pretrained(_LLM_MODEL_ID)
pipe = pipeline(task="text-generation", model=llm_model, tokenizer=tokenizer)
|
|
|
|
|
# --- TTS: Meta MMS English VITS checkpoint (tokenizer + model share one repo id) ---
_TTS_MODEL_ID = "facebook/mms-tts-eng"
tts_tokenizer = VitsTokenizer.from_pretrained(_TTS_MODEL_ID)
tts_model = VitsModel.from_pretrained(_TTS_MODEL_ID)
|
|
|
|
|
def transcribe(audio_filepath):
    """Transcribe an audio file to English text with Canary-1B.

    The input is loaded as mono, resampled to ``SAMPLE_RATE`` when needed, and
    written to a temporary WAV file together with a single-entry NeMo manifest
    (English -> English "asr" task, ``pnc="no"``). Clips shorter than 40 seconds
    are decoded in one pass with ``asr_model.transcribe``; longer clips go
    through NeMo's buffered multitask-AED inference using ``frame_asr``.

    Args:
        audio_filepath: Path to the recorded audio (what ``gr.Audio`` with
            ``type="filepath"`` provides), or ``None`` if nothing was recorded.

    Returns:
        The transcription as a plain string.

    Raises:
        gr.Error: If ``audio_filepath`` is ``None``.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Normalize the input to mono at SAMPLE_RATE and save it as a WAV file
        # that the manifest can reference.
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        duration = len(data) / SAMPLE_RATE
        # NOTE(review): "pnc": "no" presumably disables punctuation/capitalization.
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w", encoding="utf-8") as fout:
            fout.write(json.dumps(manifest_data))

        # Inference only: disable autograd so no gradient state is kept.
        # Both branches must finish before tmpdir (audio + manifest) is removed.
        with torch.no_grad():
            if duration < 40:
                result = asr_model.transcribe(manifest_filepath)[0]
            else:
                # Long clip: 40 s matches frame_asr's frame_len/total_buffer.
                result = get_buffered_pred_feat_multitaskAED(
                    frame_asr,
                    asr_model.cfg.preprocessor,
                    model_stride_in_secs,
                    asr_model.device,
                    manifest=manifest_filepath,
                )[0]

    # The buffered path yields Hypothesis objects, and asr_model.transcribe
    # returns plain strings in some NeMo versions but Hypothesis objects in
    # others. Normalize to str so callers always receive text.
    return getattr(result, "text", result)
|
|
|
|
|
def generate_text(input_text):
    """Generate an assistant reply with the Phi-3-mini instruct pipeline.

    Args:
        input_text: The user's message as a string. A pre-built list of chat
            messages (``{"role": ..., "content": ...}`` dicts) is also accepted
            and passed to the pipeline unchanged.

    Returns:
        Only the model's reply, with surrounding whitespace stripped. The prompt
        is not echoed back.
    """
    if isinstance(input_text, str):
        # Phi-3-mini-instruct is a chat model. Passing the text as a user turn
        # makes the pipeline apply the model's chat template instead of doing
        # raw text completion on the transcript.
        messages = [{"role": "user", "content": input_text}]
    else:
        messages = input_text

    generation_args = {
        "max_new_tokens": 200,
        # Return only the newly generated reply. With return_full_text=True the
        # prompt is echoed; for chat input the whole conversation comes back as
        # a list of message dicts instead of a string.
        "return_full_text": False,
        # Greedy decoding. A temperature setting is ignored when do_sample=False
        # and only triggers a generation-config warning, so it is omitted.
        "do_sample": False,
    }
    generated_text = pipe(messages, **generation_args)[0]["generated_text"]
    return generated_text.strip()
|
|
|
|
|
def gen_speech(text):
    """Synthesize ``text`` with the MMS-TTS VITS model and save it as a WAV file.

    Args:
        text: The text to speak.

    Returns:
        Path to a uniquely named WAV file, written at
        ``tts_model.config.sampling_rate``, in the system temp directory.
    """
    # Fix the seed so the same text yields the same audio. The VITS forward
    # pass samples noise, so it is not deterministic without a seed.
    set_seed(555)
    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**inputs)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Write to the system temp dir rather than the current working directory.
    # One file is created per request, so writing to CWD littered the app folder.
    output_file = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.wav")
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
|
|
|
|
|
def process_audio(audio_filepath):
    """Run the voice-assistant pipeline on one recording: ASR -> LLM -> TTS.

    Args:
        audio_filepath: Path to the recorded audio, or ``None`` if nothing was
            recorded.

    Returns:
        A ``(transcription, generated_text, audio_output_filepath)`` tuple,
        one value per output component of the Gradio interface.

    Raises:
        gr.Error: If no audio was provided or no speech was recognized.
    """
    transcription = transcribe(audio_filepath)
    print("Done transcribing")

    # Silence or noise can produce an empty transcript. Stop with a message
    # shown in the UI rather than prompting the LLM and TTS with nothing.
    if not transcription.strip():
        raise gr.Error("No speech was detected in the input audio. Please try again.")

    generated_text = generate_text(transcription)
    print("Done generating")

    audio_output_filepath = gen_speech(generated_text)
    print("Done speaking")

    return transcription, generated_text, audio_output_filepath
|
|
|
|
|
# Gradio UI: microphone input -> (transcription, reply text, reply audio).
# Bound to `demo` so it can be imported (e.g. by `gradio app.py` reload mode)
# without immediately starting a server.
demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")],
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Generated Text"),
        gr.Audio(type="filepath", label="Generated Speech"),
    ],
    title="YOUR AWESOME AI ASSISTANT",
    description=(
        "Gets input audio from the user, transcribes it with the Canary-1B ASR model, "
        "generates a reply with the Phi-3 LLM, and converts the reply back to speech "
        "with VITS TTS."
    ),
)

if __name__ == "__main__":
    demo.launch(inbrowser=True)