import os

import gradio as gr
import google.generativeai as genai
import librosa
import numpy as np
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoModel
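
# This demo wires three models into one voice-chat loop: speech is transcribed
# with AI4Bharat's IndicConformer (ASR), answered by Google's Gemini (LLM), and
# spoken back with IndicParler-TTS, all behind a Gradio interface.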

ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
TARGET_SAMPLE_RATE = 16000  # audio is resampled to 16 kHz mono before ASR

TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
# Read the key from the environment only; never hardcode API keys in source.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized global model handles, populated by load_all_resources_gradio().
asr_model_gradio = None
gemini_model_instance_gradio = None
tts_model_gradio = None
tts_tokenizer_gradio = None

def load_all_resources_gradio():
    global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
    print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")

    if asr_model_gradio is None:
        print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
        try:
            asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
            asr_model_gradio.to(DEVICE)
            # Half precision roughly halves GPU memory use; only applied on CUDA.
            if DEVICE == "cuda" and hasattr(asr_model_gradio, "half"):
                print("Gradio: Applying .half() to ASR model.")
                asr_model_gradio.half()
            asr_model_gradio.eval()
            print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
        except Exception as e:
            print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
            import traceback
            traceback.print_exc()
            asr_model_gradio = None

    if tts_model_gradio is None:
        print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        try:
            tts_tokenizer_gradio = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
            tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(
                TTS_MODEL_NAME, trust_remote_code=True
            ).to(DEVICE)
            print("Gradio: IndicParler-TTS model loaded.")
        except Exception as e:
            print(f"Gradio: Failed to load TTS model {TTS_MODEL_NAME}: {e}")
            tts_model_gradio = None
            tts_tokenizer_gradio = None

    if gemini_model_instance_gradio is None:
        if not GEMINI_API_KEY:
            print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
                print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
            except Exception as e:
                print(f"Gradio: Failed to configure Gemini API: {e}")
                gemini_model_instance_gradio = None

    print("Gradio: All resources loaded (or attempted).")


def transcribe_audio_gradio(audio_input_tuple):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."

    if audio_input_tuple is None:
        print("Gradio: No audio provided to transcribe_audio_gradio.")
        return "No audio provided."

    sample_rate, audio_numpy = audio_input_tuple

    if audio_numpy is None or audio_numpy.size == 0:
        print("Gradio: Audio numpy array is empty.")
        return "Empty audio received."

    # Downmix stereo to mono. Gradio may deliver (channels, samples) or
    # (samples, channels), so handle both layouts.
    if audio_numpy.ndim == 2:
        if audio_numpy.shape[0] == 2:
            audio_numpy = librosa.to_mono(audio_numpy)
        elif audio_numpy.shape[1] == 2:
            audio_numpy = np.mean(audio_numpy, axis=1)

    # Normalize integer PCM to float32 in [-1.0, 1.0].
    if audio_numpy.dtype != np.float32:
        if np.issubdtype(audio_numpy.dtype, np.integer):
            audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
        else:
            audio_numpy = audio_numpy.astype(np.float32)

    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
        try:
            audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
        except Exception as e:
            print(f"Gradio: Error during resampling: {e}")
            return f"Error during audio resampling: {str(e)}"

    try:
        print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")

        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.squeeze()
            if audio_numpy.ndim > 1:
                print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
                return "Error: Audio processing resulted in unexpected dimensions."

        # The remote-code model expects a (batch, samples) tensor.
        wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
        if wav_tensor.ndim == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")

        # The IndicConformer model is called directly with the waveform, a
        # language code, and a decoding strategy. The language is hardcoded to
        # Hindi ("hi") here; change it to transcribe other supported languages.
        with torch.no_grad():
            transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc")

        if isinstance(transcription_result, list) and len(transcription_result) > 0:
            transcribed_text = transcription_result[0]
        elif isinstance(transcription_result, str):
            transcribed_text = transcription_result
        else:
            print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
            transcribed_text = "ASR result format not recognized."

        transcribed_text = transcribed_text.strip()
        print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
        return transcribed_text if transcribed_text else "Transcription resulted in empty text."
    except Exception as e:
        print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
        import traceback
        traceback.print_exc()
        return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"


def generate_gemini_response_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."

    # Skip the LLM when upstream transcription failed or returned a sentinel message.
    upstream_failures = (
        "No audio provided",
        "Transcription resulted in empty text",
        "Empty audio received",
        "ASR result format not recognized",
    )
    if (
        not isinstance(text_input, str)
        or not text_input.strip()
        or text_input.startswith("Error:")
        or any(sentinel in text_input for sentinel in upstream_failures)
    ):
        print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
        return "LLM (Gemini) skipped due to transcription issue or no input."

    try:
        print(f"Gradio: Sending to Gemini: '{text_input}'")
        full_prompt = f"User: {text_input}\nAssistant:"
        response = gemini_model_instance_gradio.generate_content(full_prompt)
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            # Surface safety-filter or other prompt feedback when no text came back.
            feedback_info = ""
            if hasattr(response, "prompt_feedback") and response.prompt_feedback:
                feedback_info = f" Feedback: {response.prompt_feedback}"
            print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
            response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"

        print(f"Gradio: Gemini LLM Response: {response_text}")
        return response_text if response_text else "Gemini LLM generated an empty response."
    except Exception as e:
        print(f"Gradio: Error during Gemini LLM generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during Gemini LLM generation: {str(e)}"


def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
    if tts_model_gradio is None or tts_tokenizer_gradio is None:
        return "Error: TTS model or its tokenizer not loaded."

    # Skip synthesis when an upstream stage failed or returned a sentinel message.
    # The sentinel must match the exact "LLM (Gemini) skipped" string returned by
    # generate_gemini_response_gradio.
    upstream_failures = (
        "LLM (Gemini) skipped",
        "generated an empty response",
        "not configured",
        "ASR result format not recognized",
    )
    if (
        not isinstance(text_input, str)
        or not text_input.strip()
        or text_input.startswith("Error:")
        or any(sentinel in text_input for sentinel in upstream_failures)
    ):
        print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
        return "TTS skipped due to LLM issue or no input."

    try:
        print(f"Gradio: Synthesizing speech for: '{text_input}'")
        # Parler-TTS conditions on two text inputs: a natural-language voice
        # description and the prompt text to speak.
        description_tokenized = tts_tokenizer_gradio(
            description, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        description_ids = description_tokenized.input_ids.to(DEVICE)
        description_attention_mask = description_tokenized.attention_mask.to(DEVICE)

        prompt_tokenized = tts_tokenizer_gradio(
            text_input, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        prompt_ids = prompt_tokenized.input_ids.to(DEVICE)

        if prompt_ids.shape[-1] == 0:
            print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
            return "TTS skipped: Input text resulted in empty tokens."

        generation = tts_model_gradio.generate(
            input_ids=description_ids,
            attention_mask=description_attention_mask,
            prompt_input_ids=prompt_ids,
            do_sample=True, temperature=0.7, top_k=50, top_p=0.95,
        ).cpu().numpy().squeeze()

        sampling_rate = tts_model_gradio.config.sampling_rate
        print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
        return (sampling_rate, generation)
    except Exception as e:
        print(f"Gradio: Error during speech synthesis: {e}")
        import traceback
        traceback.print_exc()
        if "You need to specify either `text` or `text_target`" in str(e):
            return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
        return f"Error during speech synthesis: {str(e)}"


# Load models once at import time so the first user request is not delayed.
load_all_resources_gradio()


def full_pipeline_gradio(audio_input):
    # Step 1: speech-to-text.
    transcribed_text_output = transcribe_audio_gradio(audio_input)
    print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")

    # Step 2: LLM response.
    llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
    print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")

    # Step 3: text-to-speech. Success yields a (sample_rate, ndarray) tuple;
    # failure yields an error string.
    tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)

    final_audio_output = None
    if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
        final_audio_output = tts_synthesis_result
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape}")
    else:
        error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")

        # Append the TTS error to the visible text output so the user sees what went wrong.
        if llm_response_text_output and not llm_response_text_output.startswith("Error:") and "LLM (Gemini) skipped" not in llm_response_text_output and "ASR result format not recognized" not in llm_response_text_output:
            llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
        else:
            llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"

        # Fall back to one silent sample so the audio component still renders.
        default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, "config") else TARGET_SAMPLE_RATE
        final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
        print("DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")

    print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
    return transcribed_text_output, llm_response_text_output, final_audio_output
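
# The three return values map positionally onto [transcription_out,
# llm_response_out, audio_out] in the process_button.click() wiring below.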


with gr.Blocks(title="Conversational AI Demo") as demo:
    gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
        process_button = gr.Button("Process Audio")
    with gr.Accordion("Outputs", open=True):
        transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
        llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
        audio_out = gr.Audio(label="Assistant Says (Audio)")

    process_button.click(
        fn=full_pipeline_gradio,
        inputs=[audio_in],
        outputs=[transcription_out, llm_response_out, audio_out],
    )
    gr.Markdown("---")
    gr.Markdown("### How to Use:")
    gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
    gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
    gr.Markdown("3. Click the 'Process Audio' button.")
    gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")


if __name__ == "__main__":
    demo.launch(share=False)
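    # share=False keeps the server on localhost; share=True would request a
    # temporary public gradio.live URL instead.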