import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import soundfile as sf
import numpy as np
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
import IPython.display as ipd  # Only needed when experimenting in a notebook; the Gradio app itself does not use it
# --- Whisper (ASR) Setup ---
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
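# chunk_length_s=30 lets the pipeline chunk and transcribe audio longer than
# Whisper's native 30-second window.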
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
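# NOTE: indexing from the end of all_special_ids mirrors the original Whisper demo
# Space and assumes the <|transcribe|>/<|translate|> task tokens sit at fixed
# positions near the end of the special-token list; a more robust alternative
# would be asr_pipe.tokenizer.convert_tokens_to_ids("<|transcribe|>").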
# --- FastSpeech2 (TTS) Setup - Using fairseq ---
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the fairseq model, config, and task.
tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
    TTS_MODEL_NAME,
    arg_overrides={"vocoder": "hifigan", "fp16": False}
)
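# The "hifigan" vocoder / fp16=False overrides follow the usage example on the
# facebook/fastspeech2-en-ljspeech model card; other fairseq TTS checkpoints may
# expect different overrides.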
tts_model = tts_models[0]
TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
tts_generator = tts_task.build_generator([tts_model], tts_cfg)  # build_generator expects a list of models
# Move the fairseq model to the correct device.
tts_model.to(tts_device)
tts_model.eval()  # Put the model in evaluation mode
# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5"  # Or your preferred Vicuna
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
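# load_in_8bit=True requires the bitsandbytes package, and device_map="auto"
# lets accelerate place the quantized weights on the available GPU(s), so the
# model itself never needs an explicit .to(vicuna_device) call.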
# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]
    # --- VICUNA INTEGRATION ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
    prompt = f"{system_prompt}\nUser: {text}"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
        vicuna_response = vicuna_tokenizer.decode(vicuna_output[0], skip_special_tokens=True)
    vicuna_response = vicuna_response.replace(prompt, "").strip()
    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state
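# Illustrative call (assumes a local file "sample.wav" exists); the Gradio event
# handler below passes the recorded file path and the running state instead:
#   new_state, shown_text = transcribe_audio("sample.wav", state="")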
# --- TTS Function (Modified for fairseq) ---
def synthesize_speech(text):
    try:
        sample = TTSHubInterface.get_model_input(tts_task, text)
        # Move input tensors to the correct device (skip non-tensor entries such as None)
        if torch.cuda.is_available():
            sample['net_input'] = {k: v.cuda() if torch.is_tensor(v) else v for k, v in sample['net_input'].items()}
        else:
            sample['net_input'] = {k: v.cpu() if torch.is_tensor(v) else v for k, v in sample['net_input'].items()}
        wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
        wav_numpy = wav.cpu().numpy()  # fairseq returns a tensor, not a NumPy array
        return (rate, wav_numpy)  # Gradio's "numpy" audio format: (sample_rate, waveform array)
    except Exception as e:
        print(e)
        return (None, None)
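# Optional local sanity check (not used by the Gradio app); writes the synthesized
# audio to disk via the soundfile import above:
#   rate, wav = synthesize_speech("Hello! Nice to meet you.")
#   if rate is not None:
#       sf.write("tts_test.wav", wav, rate)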
# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        mic_input = gr.Audio(source="microphone", type="filepath", optional=True, label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")
        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state]
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output
        )

demo.launch(enable_queue=True, share=False)