import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
# --- Whisper (ASR) Setup ---
ASR_MODEL_NAME = "openai/whisper-large-v2"
asr_device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    chunk_length_s=30,
    device=asr_device,
)
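# The task tokens sit at fixed offsets among Whisper's special tokens: for this
# tokenizer, index -5 is <|transcribe|> and -6 is <|translate|>. This indexing is
# brittle (it assumes the whisper-large-v2 tokenizer layout); newer transformers
# versions accept generate_kwargs={"task": ...} on the pipeline call instead.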
all_special_ids = asr_pipe.tokenizer.all_special_ids
transcribe_token_id = all_special_ids[-5]
translate_token_id = all_special_ids[-6]
# --- FastSpeech2 (TTS) Setup - Using fairseq ---
TTS_MODEL_NAME = "facebook/fastspeech2-en-ljspeech"
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the fairseq model, config, and task.
tts_models, tts_cfg, tts_task = load_model_ensemble_and_task_from_hf_hub(
    TTS_MODEL_NAME,
    arg_overrides={"vocoder": "hifigan", "fp16": False},
)
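# The hub loader returns (list of models, config, task); use the first model.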
tts_model = tts_models[0]
TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, tts_task.data_cfg)
tts_generator = tts_task.build_generator([tts_model], tts_cfg)  # build_generator expects a list of models
# Move the fairseq model to the correct device.
tts_model.to(tts_device)
tts_model.eval() # Put the model in evaluation mode
# --- Vicuna (LLM) Setup ---
VICUNA_MODEL_NAME = "lmsys/vicuna-7b-v1.5" # Or your preferred Vicuna
vicuna_device = "cuda" if torch.cuda.is_available() else "cpu"
vicuna_tokenizer = AutoTokenizer.from_pretrained(VICUNA_MODEL_NAME)
vicuna_model = AutoModelForCausalLM.from_pretrained(
    VICUNA_MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
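# load_in_8bit requires the bitsandbytes package, and device_map="auto" lets
# accelerate place the quantized weights, so no manual .to(device) call is needed.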
# --- ASR Function ---
def transcribe_audio(microphone, state, task="transcribe"):
    if microphone is None:
        return state, state
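    # Force the task token at decoder position 2; position 1 is left free for
    # Whisper's auto-detected language token.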
    asr_pipe.model.config.forced_decoder_ids = [
        [2, transcribe_token_id if task == "transcribe" else translate_token_id]
    ]
    text = asr_pipe(microphone)["text"]
    # --- VICUNA INTEGRATION ---
    system_prompt = """You are a friendly and enthusiastic tutor for young children (ages 6-9).
You answer questions clearly and simply, using age-appropriate language.
You are also a little bit silly and like to make jokes."""
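    # Vicuna v1.5 was tuned on the "USER: ... ASSISTANT:" conversation template,
    # so the prompt ends with "ASSISTANT:" to cue the model's reply.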
    prompt = f"{system_prompt}\nUSER: {text}\nASSISTANT:"
    with torch.no_grad():
        vicuna_input = vicuna_tokenizer(prompt, return_tensors="pt").to(vicuna_model.device)
        vicuna_output = vicuna_model.generate(**vicuna_input, max_new_tokens=128)
    # Decode only the newly generated tokens; string-replacing the prompt out of the
    # full decode is brittle because decoding can normalize whitespace.
    new_tokens = vicuna_output[0][vicuna_input["input_ids"].shape[-1]:]
    vicuna_response = vicuna_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    updated_state = state + "\n" + vicuna_response
    return updated_state, updated_state
# --- TTS Function (Modified for fairseq) ---
def synthesize_speech(text):
    try:
        sample = TTSHubInterface.get_model_input(tts_task, text)
        # Move only the tensors to the TTS device; get_model_input also returns
        # non-tensor entries (prev_output_tokens is None for FastSpeech2).
        sample["net_input"] = {
            k: v.to(tts_device) if torch.is_tensor(v) else v
            for k, v in sample["net_input"].items()
        }
        wav, rate = TTSHubInterface.get_prediction(tts_task, tts_model, tts_generator, sample)
        wav_numpy = wav.cpu().numpy()  # fairseq returns a tensor; Gradio wants a NumPy array
        return (rate, wav_numpy)
    except Exception as e:
        print(f"TTS synthesis failed: {e}")
        return None  # Gradio treats None as "no audio" for an Audio output
# --- Gradio Interface ---
with gr.Blocks(title="Whisper, Vicuna, & FastSpeech2 Demo") as demo:
    gr.Markdown("# Speech-to-Text-to-Speech Demo with Vicuna")
    gr.Markdown("Speak into your microphone, get a transcription, Vicuna will process it, and then you'll hear the result!")
    with gr.Tab("Transcribe & Synthesize"):
        # An empty microphone input arrives as None, which transcribe_audio handles;
        # the old optional= flag was dropped from Gradio 3's Audio component.
        mic_input = gr.Audio(source="microphone", type="filepath", label="Speak Here")
        transcription_output = gr.Textbox(lines=5, label="Transcription and Vicuna Response")
        audio_output = gr.Audio(label="Synthesized Speech", type="numpy")
        transcription_state = gr.State(value="")
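        # audio_output uses type="numpy", so it expects the (sample_rate, waveform)
        # tuple that synthesize_speech returns.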
        mic_input.change(
            fn=transcribe_audio,
            inputs=[mic_input, transcription_state],
            outputs=[transcription_output, transcription_state],
        ).then(
            fn=synthesize_speech,
            inputs=transcription_output,
            outputs=audio_output,
        )
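        # Note: transcription_output accumulates the whole conversation, so each turn
        # re-synthesizes the full history rather than only the newest reply.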
demo.queue().launch(share=False)  # Gradio 3 replaces launch(enable_queue=True) with .queue()