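"""Voice assistant demo (Gradio): Canary-1B ASR -> Phi-3-mini LLM -> MMS VITS TTS."""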
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, VitsModel, VitsTokenizer, pipeline, set_seed
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
import scipy.io.wavfile as wav
# Constants
SAMPLE_RATE = 16000 # Hz
# Load ASR model
asr_model = ASRModel.from_pretrained("nvidia/canary-1b")
asr_model.eval()
asr_model.change_decoding_strategy(None)
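# Use greedy decoding (beam size 1) for fast, deterministic transcription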
decoding_cfg = asr_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
asr_model.change_decoding_strategy(decoding_cfg)
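# Disable dither and feature padding so preprocessing is deterministic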
asr_model.cfg.preprocessor.dither = 0.0
asr_model.cfg.preprocessor.pad_to = 0
feature_stride = asr_model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8  # the FastConformer encoder downsamples features by 8x
# Buffered (chunked) inference helper for audio longer than one 40 s window
frame_asr = FrameBatchMultiTaskAED(
    asr_model=asr_model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)
# Load LLM model
torch.random.manual_seed(0)  # reproducible generation
llm_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
pipe = pipeline("text-generation", model=llm_model, tokenizer=tokenizer)
# Load TTS model
tts_tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
# Function to convert audio to text using ASR
def transcribe(audio_filepath):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")
    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Convert input to 16 kHz mono
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)
        # Write a one-line NeMo manifest describing the transcription task
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, 'w') as fout:
            fout.write(json.dumps(manifest_data))
        # Short clips fit in one pass; longer audio goes through buffered (chunked) inference
        if duration < 40:
            transcription = asr_model.transcribe(manifest_filepath)[0]
        else:
            transcription = get_buffered_pred_feat_multitaskAED(
                frame_asr,
                asr_model.cfg.preprocessor,
                model_stride_in_secs,
                asr_model.device,
                manifest=manifest_filepath,
            )[0].text
    return transcription
# Function to generate text using LLM
def generate_text(input_text):
    messages = input_text
    generation_args = {
        "max_new_tokens": 200,
        "return_full_text": True,  # the prompt is echoed back in front of the reply
        "temperature": 0.0,
        "do_sample": False,  # greedy decoding
    }
    generated_text = pipe(messages, **generation_args)[0]["generated_text"]
    return generated_text
# Function to convert text to speech using TTS
def gen_speech(text):
    set_seed(555)  # Make it deterministic
    input_text = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()
    output_file = f"{str(uuid.uuid4())}.wav"
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
# Combined function for Gradio interface
def process_audio(audio_filepath):
    transcription = transcribe(audio_filepath)
    print("Done transcribing")
    generated_text = generate_text(transcription)
    print("Done generating")
    audio_output_filepath = gen_speech(generated_text)
    print("Done speaking")
    return transcription, generated_text, audio_output_filepath
# Create Gradio interface
gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")],
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Generated Text"),
        gr.Audio(type="filepath", label="Generated Speech"),
    ],
    title="YOUR AWESOME AI ASSISTANT",
    description="Takes audio from the user's microphone, transcribes it with the Canary-1B ASR model, generates a response with the Phi-3 LLM, and converts that response back to speech with VITS TTS.",
).launch(inbrowser=True)