# gradio_app.py
import gradio as gr
import io
import os
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoModel # CHANGED: Using AutoModel as per model card
import numpy as np
import google.generativeai as genai
import asyncio
import librosa
import torchaudio # The ASR model's remote code may use torchaudio internally for audio loading/processing
# --- Configuration ---
ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
TARGET_SAMPLE_RATE = 16000 # Model expects 16kHz
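# ASR call settings. The model card's example call is `model(wav, "hi", "ctc")`; the two constants
# below are an assumption that the remote code accepts any of its supported language codes and
# either "ctc" or "rnnt" as the decoding strategy. Adjust ASR_LANGUAGE_ID to match the input speech.
ASR_LANGUAGE_ID = "hi"         # Hindi, as in the model card example (assumed configurable)
ASR_DECODING_STRATEGY = "ctc"  # "rnnt" may also be supported (assumption based on the model card)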
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Read from the environment; never hard-code the API key.
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ParlerTTS runs in the default torch dtype here; the ASR model's remote code may manage its own precision.
# --- Global Model Variables ---
asr_model_gradio = None # This will be the AutoModel instance
gemini_model_instance_gradio = None
tts_model_gradio = None
tts_tokenizer_gradio = None # For ParlerTTS

# --- Model Loading & API Configuration ---
def load_all_resources_gradio():
    global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
    print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")

    if asr_model_gradio is None:
        print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
        try:
            # Load via AutoModel with trust_remote_code, as the model card suggests.
            asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
            asr_model_gradio.to(DEVICE)  # Move model to device
            # The remote code may manage its own precision; on CUDA, apply .half() when supported.
            if DEVICE == "cuda" and hasattr(asr_model_gradio, 'half'):
                print("Gradio: Applying .half() to ASR model.")
                asr_model_gradio.half()
            asr_model_gradio.eval()
            print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
        except Exception as e:
            print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
            import traceback
            traceback.print_exc()
            asr_model_gradio = None

    if tts_model_gradio is None:  # ParlerTTS loading
        print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        # ParlerTTS needs its own tokenizer; the ASR model's remote code handles its own preprocessing.
        tts_parler_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
        tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True).to(DEVICE)
        tts_tokenizer_gradio = tts_parler_tokenizer
        print("Gradio: IndicParler-TTS model loaded.")

    if gemini_model_instance_gradio is None:  # Gemini configuration
        if not GEMINI_API_KEY:
            print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
                print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
            except Exception as e:
                print(f"Gradio: Failed to configure Gemini API: {e}")
                gemini_model_instance_gradio = None

    print("Gradio: All resources loaded (or attempted).")

# --- Helper Functions ---
def transcribe_audio_gradio(audio_input_tuple):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
    if audio_input_tuple is None:
        print("Gradio: No audio provided to transcribe_audio_gradio.")
        return "No audio provided."

    sample_rate, audio_numpy = audio_input_tuple
    if audio_numpy is None or audio_numpy.size == 0:
        print("Gradio: Audio numpy array is empty.")
        return "Empty audio received."

    # Ensure audio is mono float32, which is a common expectation
    if audio_numpy.ndim > 1:
        if audio_numpy.shape[0] == 2 and audio_numpy.ndim == 2:
            audio_numpy = librosa.to_mono(audio_numpy)
        elif audio_numpy.shape[1] == 2 and audio_numpy.ndim == 2:
            audio_numpy = np.mean(audio_numpy, axis=1)
    if audio_numpy.dtype != np.float32:
        if np.issubdtype(audio_numpy.dtype, np.integer):
            audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
        else:
            audio_numpy = audio_numpy.astype(np.float32)

    # Resample to TARGET_SAMPLE_RATE (16 kHz)
    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
        try:
            audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
        except Exception as e:
            print(f"Gradio: Error during resampling: {e}")
            return f"Error during audio resampling: {str(e)}"
    try:
        print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")
        # The model card's example call (`model(wav, "hi", "ctc")`) takes a waveform tensor loaded
        # with torchaudio, so convert the numpy array to a tensor on the target device.
        # Ensure the audio is 1D (single channel) before adding a batch dimension.
        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.squeeze()  # Attempt to remove singleton dimensions
            if audio_numpy.ndim > 1:
                print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
                return "Error: Audio processing resulted in unexpected dimensions."
        wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
        # The model may expect a batch dimension, e.g. [1, num_samples].
        if wav_tensor.ndim == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using {ASR_DECODING_STRATEGY.upper()}. Input tensor shape: {wav_tensor.shape}")

        # Call the model as shown on the model card: waveform, language id, decoding strategy.
        # The model is trained on 22 Indian languages, but the card does not document language
        # auto-detection, so an explicit language id is passed here.
        with torch.no_grad():  # Good practice for inference
            transcription_result = asr_model_gradio(wav_tensor, ASR_LANGUAGE_ID, ASR_DECODING_STRATEGY)

        # The model card implies the transcribed string is returned directly; also handle a list
        # (e.g. one entry per batch item) in case batching occurs.
        if isinstance(transcription_result, list) and len(transcription_result) > 0:
            transcribed_text = transcription_result[0]  # First result for non-batched input
        elif isinstance(transcription_result, str):
            transcribed_text = transcription_result
        else:
            print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
            transcribed_text = "ASR result format not recognized."

        transcribed_text = transcribed_text.strip()
        print(f"Gradio: Transcription ({ASR_MODEL_NAME}, {ASR_DECODING_STRATEGY.upper()}): {transcribed_text}")
        return transcribed_text if transcribed_text else "Transcription resulted in empty text."
    except Exception as e:
        print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
        import traceback
        traceback.print_exc()
        return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"

def generate_gemini_response_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."

    # Skip the LLM call when the upstream transcription step produced an error or placeholder message.
    skip_markers = ("No audio provided", "Transcription resulted in empty text",
                    "Empty audio received", "ASR result format not recognized")
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(marker in text_input for marker in skip_markers)):
        print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
        return "LLM (Gemini) skipped due to transcription issue or no input."

    try:
        print(f"Gradio: Sending to Gemini: '{text_input}'")
        full_prompt = f"User: {text_input}\nAssistant:"
        response = gemini_model_instance_gradio.generate_content(full_prompt)
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            feedback_info = ""
            if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
                feedback_info = f" Feedback: {response.prompt_feedback}"
            print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
            response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
        print(f"Gradio: Gemini LLM Response: {response_text}")
        return response_text if response_text else "Gemini LLM generated an empty response."
    except Exception as e:
        print(f"Gradio: Error during Gemini LLM generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during Gemini LLM generation: {str(e)}"

def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
    if tts_model_gradio is None or tts_tokenizer_gradio is None:
        return "Error: TTS model or its tokenizer not loaded."

    # Skip synthesis when the upstream LLM step produced an error or placeholder message.
    skip_markers = ("LLM skipped", "generated an empty response", "not configured",
                    "ASR result format not recognized")
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(marker in text_input for marker in skip_markers)):
        print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
        return "TTS skipped due to LLM issue or no input."

    try:
        print(f"Gradio: Synthesizing speech for: '{text_input}'")
        description_tokenized = tts_tokenizer_gradio(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
        description_ids = description_tokenized.input_ids.to(DEVICE)
        description_attention_mask = description_tokenized.attention_mask.to(DEVICE)
        prompt_tokenized = tts_tokenizer_gradio(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        prompt_ids = prompt_tokenized.input_ids.to(DEVICE)
        if prompt_ids.shape[-1] == 0:  # Check if the tokenized prompt is empty
            print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
            return "TTS skipped: Input text resulted in empty tokens."

        generation = tts_model_gradio.generate(
            input_ids=description_ids,
            attention_mask=description_attention_mask,
            prompt_input_ids=prompt_ids,
            do_sample=True, temperature=0.7, top_k=50, top_p=0.95
        ).cpu().numpy().squeeze()
        sampling_rate = tts_model_gradio.config.sampling_rate
        print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
        return (sampling_rate, generation)
    except Exception as e:
        print(f"Gradio: Error during speech synthesis: {e}")
        import traceback
        traceback.print_exc()
        if "You need to specify either `text` or `text_target`" in str(e):
            return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
        return f"Error during speech synthesis: {str(e)}"

# --- Gradio Interface Definition ---
load_all_resources_gradio()

def full_pipeline_gradio(audio_input):
    transcribed_text_output = transcribe_audio_gradio(audio_input)
    print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")

    llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
    print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")

    tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
    final_audio_output = None
    if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
        final_audio_output = tts_synthesis_result
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape}")
    else:
        error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
        # Append the TTS error to the LLM text only if the LLM text was valid;
        # otherwise keep the existing error message and note that TTS also had an issue.
        if llm_response_text_output and not llm_response_text_output.startswith("Error:") and "LLM skipped" not in llm_response_text_output and "ASR result format not recognized" not in llm_response_text_output:
            llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
        else:
            llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
        default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, 'config') else TARGET_SAMPLE_RATE
        final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
        print("DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")

    print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
    return transcribed_text_output, llm_response_text_output, final_audio_output

with gr.Blocks(title="Conversational AI Demo") as demo:
    gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
        process_button = gr.Button("Process Audio")
    with gr.Accordion("Outputs", open=True):
        transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
        llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
        audio_out = gr.Audio(label="Assistant Says (Audio)")

    process_button.click(
        fn=full_pipeline_gradio,
        inputs=[audio_in],
        outputs=[transcription_out, llm_response_out, audio_out]
    )

    gr.Markdown("---")
    gr.Markdown("### How to Use:")
    gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
    gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
    gr.Markdown("3. Click the 'Process Audio' button.")
    gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")

if __name__ == "__main__":
    demo.launch(share=False)
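
# Minimal (assumed) way to run this demo, with the key supplied via the environment:
#   GEMINI_API_KEY=<your key> python gradio_app.py
# Gradio then serves the app locally; pass share=True to demo.launch() if you need a public link.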