# gradio_app.py
import gradio as gr
import os
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoModel  # AutoModel with trust_remote_code, per the ASR model card
import numpy as np
import google.generativeai as genai
import librosa
import torchaudio  # Not called directly here; the ASR model card loads audio with it (see the resampling note below)

# --- Configuration ---
ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
TARGET_SAMPLE_RATE = 16000 # Model expects 16kHz

TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Never hardcode the key; read it from the environment.
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ParlerTTS runs in float32 here; the ASR model's remote code manages its own precision (half precision is applied below on CUDA).

# --- Global Model Variables ---
asr_model_gradio = None # This will be the AutoModel instance

gemini_model_instance_gradio = None
tts_model_gradio = None
tts_tokenizer_gradio = None # For ParlerTTS

# --- Model Loading & API Configuration ---
def load_all_resources_gradio():
    global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio
    print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}")

    if asr_model_gradio is None:
        print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel")
        try:
            # Load using AutoModel as per the model card's implication
            asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True)
            asr_model_gradio.to(DEVICE) # Move model to device
            # The remote code may manage precision itself; on CUDA we additionally
            # apply .half() when the model exposes it.
            if DEVICE == "cuda" and hasattr(asr_model_gradio, 'half'):
                print("Gradio: Applying .half() to ASR model.")
                asr_model_gradio.half()
            asr_model_gradio.eval()
            print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.")
        except Exception as e:
            print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}")
            import traceback
            traceback.print_exc()
            asr_model_gradio = None

    if tts_model_gradio is None: # ParlerTTS loading
        print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        # Ensure ParlerTTS specific tokenizer is loaded for TTS
        # Note: ASR model might have its own internal tokenizer/processor handled by its custom code
        tts_parler_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
        tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True).to(DEVICE)
        tts_tokenizer_gradio = tts_parler_tokenizer
        print("Gradio: IndicParler-TTS model loaded.")

    if gemini_model_instance_gradio is None: # Gemini loading
        if not GEMINI_API_KEY:
            print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
                print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
            except Exception as e:
                print(f"Gradio: Failed to configure Gemini API: {e}")
                gemini_model_instance_gradio = None
    
    print("Gradio: All resources loaded (or attempted).")


# --- Helper Functions ---
def transcribe_audio_gradio(audio_input_tuple):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
    
    if audio_input_tuple is None:
        print("Gradio: No audio provided to transcribe_audio_gradio.")
        return "No audio provided."
        
    sample_rate, audio_numpy = audio_input_tuple

    if audio_numpy is None or audio_numpy.size == 0:
        print("Gradio: Audio numpy array is empty.")
        return "Empty audio received."

    # Ensure audio is mono float32, which most ASR front ends expect.
    # Gradio usually delivers (samples, channels); handle channels-first too.
    if audio_numpy.ndim == 2:
        if audio_numpy.shape[0] == 2:    # channels-first: (2, samples)
            audio_numpy = librosa.to_mono(audio_numpy)
        elif audio_numpy.shape[1] == 2:  # channels-last: (samples, 2)
            audio_numpy = np.mean(audio_numpy, axis=1)
    
    if audio_numpy.dtype != np.float32:
        if np.issubdtype(audio_numpy.dtype, np.integer):
            audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
        else:
            audio_numpy = audio_numpy.astype(np.float32)
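    # Example: np.iinfo(np.int16).max == 32767, so a full-scale int16 sample
    # lands at ~1.0, the [-1.0, 1.0] float range the model expects.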

    # Resample to TARGET_SAMPLE_RATE (16kHz)
    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
        try:
            audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
            # After resampling, the audio_numpy is at TARGET_SAMPLE_RATE
        except Exception as e:
            print(f"Gradio: Error during resampling: {e}")
            return f"Error during audio resampling: {str(e)}"
    
    try:
        print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")

        # The model card example `model(wav, "hi", "ctc")` passes a waveform tensor
        # (loaded via torchaudio), so convert the numpy array to a tensor on DEVICE.
        
        # Ensure the audio_numpy is 1D as expected by many ASR models for a single channel
        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.squeeze() # Attempt to remove singleton dimensions
        if audio_numpy.ndim > 1 : # If still more than 1D, problem
            print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
            return "Error: Audio processing resulted in unexpected dimensions."

        wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
        # The model might expect a batch dimension, e.g., [1, num_samples]
        if wav_tensor.ndim == 1:
            wav_tensor = wav_tensor.unsqueeze(0) 
        
        print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")

        # Decode with CTC; the model also exposes an "rnnt" strategy. The language
        # id "hi" (Hindi) follows the model card snippet `model(wav, "hi", "ctc")`;
        # the card lists 22 training languages but does not document auto-detection,
        # so a language id must be supplied explicitly.
        # The `model()` call is synchronous; Gradio runs it in a worker thread.
        with torch.no_grad():  # Disable autograd for inference.
            transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc")
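        # Hedged sketch: to make the language user-selectable, a gr.Dropdown could
        # feed a code into this function instead of the hardcoded "hi"
        # (`selected_lang` is a hypothetical parameter, not defined in this file):
        #   transcription_result = asr_model_gradio(wav_tensor, selected_lang, "ctc")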

        # The output format needs to be checked. The model card implies it's the transcribed string directly.
        # It might be a list of transcriptions if batching occurs, or a dict.
        if isinstance(transcription_result, list) and len(transcription_result) > 0:
            transcribed_text = transcription_result[0] # Assuming first result for non-batched input
        elif isinstance(transcription_result, str):
            transcribed_text = transcription_result
        else:
            print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
            transcribed_text = "ASR result format not recognized."
            
        transcribed_text = transcribed_text.strip()
        print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
        return transcribed_text if transcribed_text else "Transcription resulted in empty text."
    except Exception as e:
        print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
        import traceback
        traceback.print_exc()
        return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"


# --- LLM (Gemini) and TTS Helpers ---
def generate_gemini_response_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."
    asr_failure_markers = ("No audio provided", "Transcription resulted in empty text",
                           "Empty audio received", "ASR result format not recognized")
    if (not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:")
            or any(m in text_input for m in asr_failure_markers)):
        print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
        return "LLM (Gemini) skipped due to transcription issue or no input."
    try:
        print(f"Gradio: Sending to Gemini: '{text_input}'")
        full_prompt = f"User: {text_input}\nAssistant:"
        response = gemini_model_instance_gradio.generate_content(full_prompt)
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            feedback_info = ""
            if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
                feedback_info = f" Feedback: {response.prompt_feedback}"
            print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
            response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
            
        print(f"Gradio: Gemini LLM Response: {response_text}")
        return response_text if response_text else "Gemini LLM generated an empty response."
    except Exception as e:
        print(f"Gradio: Error during Gemini LLM generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during Gemini LLM generation: {str(e)}"

def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
    if tts_model_gradio is None or tts_tokenizer_gradio is None:
        return "Error: TTS model or its tokenizer not loaded."
    llm_failure_markers = ("LLM (Gemini) skipped", "generated an empty response",
                           "not configured", "ASR result format not recognized")
    if (not isinstance(text_input, str) or not text_input.strip() or text_input.startswith("Error:")
            or any(m in text_input for m in llm_failure_markers)):
        print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
        return "TTS skipped due to LLM issue or no input."
    try:
        print(f"Gradio: Synthesizing speech for: '{text_input}'")
        description_tokenized = tts_tokenizer_gradio(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
        description_ids = description_tokenized.input_ids.to(DEVICE)
        description_attention_mask = description_tokenized.attention_mask.to(DEVICE)

        prompt_tokenized = tts_tokenizer_gradio(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        prompt_ids = prompt_tokenized.input_ids.to(DEVICE)

        if prompt_ids.shape[-1] == 0:  # Tokenized prompt is empty.
            print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
            return "TTS skipped: Input text resulted in empty tokens."

        generation = tts_model_gradio.generate(
            input_ids=description_ids,
            attention_mask=description_attention_mask,
            prompt_input_ids=prompt_ids,
            do_sample=True, temperature=0.7, top_k=50, top_p=0.95
        ).cpu().numpy().squeeze()
        
        sampling_rate = tts_model_gradio.config.sampling_rate
        print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
        return (sampling_rate, generation)
    except Exception as e:
        print(f"Gradio: Error during speech synthesis: {e}")
        import traceback
        traceback.print_exc()
        if "You need to specify either `text` or `text_target`" in str(e):
             return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
        return f"Error during speech synthesis: {str(e)}"

# --- Gradio Interface Definition ---
load_all_resources_gradio() 

def full_pipeline_gradio(audio_input):
    transcribed_text_output = transcribe_audio_gradio(audio_input)
    print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")
    llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
    print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")
    tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
    final_audio_output = None
    if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
        final_audio_output = tts_synthesis_result
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape if isinstance(tts_synthesis_result[1], np.ndarray) else 'N/A'}")
    else:
        error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
        # Append the TTS error to the LLM text if that text was valid; otherwise
        # note that TTS also had an issue alongside the existing error text.
        llm_text_was_valid = (llm_response_text_output
                              and not llm_response_text_output.startswith("Error:")
                              and "LLM (Gemini) skipped" not in llm_response_text_output
                              and "ASR result format not recognized" not in llm_response_text_output)
        if llm_text_was_valid:
            llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
        else:
            llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
        
        default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, 'config') else TARGET_SAMPLE_RATE
        final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")
    print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
    return transcribed_text_output, llm_response_text_output, final_audio_output

with gr.Blocks(title="Conversational AI Demo") as demo:
    gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
    process_button = gr.Button("Process Audio")
    with gr.Accordion("Outputs", open=True):
        transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2) 
        llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
        audio_out = gr.Audio(label="Assistant Says (Audio)")

    process_button.click(
        fn=full_pipeline_gradio,
        inputs=[audio_in],
        outputs=[transcription_out, llm_response_out, audio_out]
    )
    gr.Markdown("---")
    gr.Markdown("### How to Use:")
    gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
    gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
    gr.Markdown("3. Click the 'Process Audio' button.")
    gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")

if __name__ == "__main__":
    demo.launch(share=False)
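    # Hedged sketch: for concurrent users, Gradio's request queue can be enabled
    # before launching (queue() is a standard Blocks method; defaults assumed):
    #   demo.queue().launch(share=False)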