jeremierostan committed on
Commit
48cf021
·
verified ·
1 Parent(s): d130f0b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import os
import re
import secrets
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Iterator, List, Optional, Tuple

import cohere
import fasttext
import gradio as gr
import requests
import torch
from groq import Groq
from huggingface_hub import hf_hub_download
17
+
18
+ # Configuration
19
@dataclass
class Config:
    """Process-wide settings and API credentials for the app.

    The field defaults are evaluated once, when this class body runs: the
    device is picked from ``torch.cuda.is_available()`` and the Groq/Cohere
    keys are read from the environment. An environment variable that is not
    set yields ``None``, hence the ``Optional[str]`` annotations.
    """

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 32
    model_name: str = "aya-expanse-32B"

    # API keys from the environment (None when the variable is unset).
    groq_api_key: Optional[str] = os.getenv("GROQ_API_KEY")
    chat_cohere_api_key: Optional[str] = os.getenv("CHAT_COHERE_API_KEY")

    # The Neets key is not read from the environment; it is supplied by the
    # user at runtime through set_neets_key().
    neets_ai_api_key: Optional[str] = None

    def set_neets_key(self, key: str) -> None:
        """Store the Neets API key, or clear it when *key* is blank.

        Surrounding whitespace is stripped because the key is later sent
        verbatim as an HTTP header value, where a stray space or newline
        from a copy/paste would make the request fail.

        Args:
            key: The key as typed by the user. ``None``, empty, or
                whitespace-only values reset the stored key to ``None``.
        """
        cleaned = key.strip() if key else ""
        self.neets_ai_api_key = cleaned or None


# Single shared instance read by the rest of the module.
config = Config()
37
+
38
+ # Initialize clients
39
def get_clients():
    """Construct the remote API clients used by the app.

    Returns:
        A dict with two entries: ``'chat'`` holds a ``cohere.Client`` and
        ``'groq'`` holds a ``Groq`` client. Both are configured with the
        API keys stored on the module-level ``config`` object.
    """
    cohere_chat = cohere.Client(
        api_key=config.chat_cohere_api_key,
        client_name="c4ai-aya-expanse-chat",
    )
    groq_client = Groq(api_key=config.groq_api_key)
    return {'chat': cohere_chat, 'groq': groq_client}


# Built once at import time and shared by everything below.
clients = get_clients()
49
+
50
+ # Language identification
51
@lru_cache(maxsize=1)
def load_lid_model():
    """Return the fastText language-identification model.

    The model file ``model.bin`` is fetched from the
    ``facebook/fasttext-language-identification`` repository via
    ``hf_hub_download`` and loaded with ``fasttext.load_model``. Because of
    the ``lru_cache`` decorator, the result of the first call is reused by
    every later call.
    """
    return fasttext.load_model(
        hf_hub_download(
            repo_id="facebook/fasttext-language-identification",
            filename="model.bin",
        )
    )
59
+
60
def predict_language(text: str) -> str:
    """Return the language label predicted for *text*.

    Newlines are replaced by spaces before the text is handed to the
    language-identification model, and the ``__label__`` prefix is removed
    from the label it returns.

    Args:
        text: Text to classify. May be ``None`` or empty.

    Returns:
        The predicted label without its ``__label__`` prefix, or
        ``"eng_Latn"`` when there is nothing to classify (``None``, empty,
        or whitespace-only input) or the model returns no label.
    """
    flattened = text.replace("\n", " ") if text else ""
    # Whitespace-only input carries no signal for the model, so it gets the
    # same English default as empty input instead of an arbitrary guess.
    if not flattened.strip():
        return "eng_Latn"

    labels, _ = load_lid_model().predict(flattened)
    if not labels:
        return "eng_Latn"
    return labels[0][len("__label__"):]
68
+
69
def clean_text(text: str, remove_bullets: bool = False, remove_newline: bool = False) -> str:
    """Strip simple Markdown formatting from *text*.

    Args:
        text: Text to clean; a falsy value yields ``""``.
        remove_bullets: When true, drop a leading ``"- "`` from every line.
        remove_newline: When true, replace each newline with a space.

    Returns:
        The text with every ``**`` removed, the optional clean-ups applied
        in the order listed above, and leading/trailing whitespace stripped.
    """
    if not text:
        return ""

    # (enabled, pattern, replacement, regex flags), applied top to bottom.
    substitutions = (
        (True, r"\*\*", "", 0),
        (remove_bullets, r"^- ", "", re.MULTILINE),
        (remove_newline, r"\n", " ", 0),
    )

    cleaned = text
    for enabled, pattern, replacement, flags in substitutions:
        if enabled:
            cleaned = re.sub(pattern, replacement, cleaned, flags=flags)
    return cleaned.strip()
81
+
82
class ConversationManager:
    """Runs the steps of one chat turn: transcription, reply, and speech.

    The class holds no conversation state of its own; history and the
    conversation id are passed in and handed back by the caller.
    """

    def __init__(self):
        # Cohere chat client taken from the module-level `clients` dict.
        self.chat_client = clients['chat']

    def check_neets_key(self) -> bool:
        """Return True when a Neets API key is stored on ``config``."""
        return bool(config.neets_ai_api_key)

    def transcribe_audio(self, audio_file: str) -> Tuple[str, str]:
        """Transcribe an audio file and detect the language of the result.

        Args:
            audio_file: Path of the recording; falsy means "no audio".

        Returns:
            ``(text, lang_code)``. Without an audio file this is
            ``("", "eng_Latn")`` and no request is made.
        """
        if not audio_file:
            return "", "eng_Latn"

        # Transcribe with the Whisper model through the Groq client.
        with open(audio_file, "rb") as f:
            transcription = clients['groq'].audio.transcriptions.create(
                file=(audio_file, f.read()),
                model="whisper-large-v3-turbo",
                response_format="json",
                temperature=0.0
            )

        text = transcription.text
        return text, predict_language(text)

    def generate_response(
        self,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        conversation_id: str = None
    ) -> Iterator[Tuple[List[Tuple[str, str]], str, str]]:
        """Stream the assistant's reply to *user_input*.

        This is a generator: it yields once per streamed text chunk, each
        time with the full state so far. Callers that only need the final
        reply should iterate to the end and keep the last item.

        Args:
            user_input: The user's message for this turn.
            chat_history: Earlier ``(user, assistant)`` pairs. The list is
                not modified; every yielded history is a new list.
            conversation_id: Id of the conversation; a fresh UUID is
                generated when it is falsy.

        Yields:
            ``(history, response, conversation_id)`` where ``history`` is
            the earlier turns plus ``(user_input, response)`` and
            ``response`` is the reply text accumulated so far.
        """
        if not conversation_id:
            conversation_id = str(uuid.uuid4())

        # Work on a copy so the caller's list is never mutated and earlier
        # turns can never be overwritten by the turn being streamed.
        prior_turns = list(chat_history or [])

        # The chat endpoint takes history as role/message entries rather
        # than a flat list of strings. Empty halves of a pair are skipped.
        formatted_history = []
        for human, assistant in prior_turns:
            if human:
                formatted_history.append({"role": "USER", "message": human})
            if assistant:
                formatted_history.append({"role": "CHATBOT", "message": assistant})

        # NOTE(review): CHAT_PREAMBLE is not defined in the visible part of
        # this file; it must exist at module level before this is called.
        # NOTE(review): both conversation_id and chat_history are sent, as
        # in the original; confirm the API is meant to receive both.
        stream = self.chat_client.chat_stream(
            message=user_input,
            preamble=CHAT_PREAMBLE,
            conversation_id=conversation_id,
            model=config.model_name,
            temperature=0.3,
            chat_history=formatted_history
        )

        response = ""
        updated_history = prior_turns
        for event in stream:
            if event.event_type == "text-generation":
                response += event.text
                updated_history = prior_turns + [(user_input, response)]
                yield updated_history, response, conversation_id

        return updated_history, response, conversation_id

    def text_to_speech(self, text: str, lang_code: str) -> Optional[str]:
        """Synthesize *text* with Neets.ai and save it as an MP3 file.

        Args:
            text: Text to speak; falsy means "nothing to say".
            lang_code: Language label used to pick the voice. Codes missing
                from ``NEETS_AI_LANGID_MAP`` fall back to ``"en"``.

        Returns:
            Path of the written ``.mp3`` file (relative to the working
            directory), or ``None`` when *text* is empty.

        Raises:
            ValueError: If no API key is set, the request cannot be
                completed, or the API answers with a non-200 status.
        """
        if not text:
            return None

        if not self.check_neets_key():
            raise ValueError("Neets API key not set. Please enter your API key.")

        # NOTE(review): NEETS_AI_LANGID_MAP is not defined in the visible
        # part of this file; it must exist at module level.
        neets_lang_id = NEETS_AI_LANGID_MAP.get(lang_code, "en")
        neets_vits_voice_id = f"vits-{neets_lang_id}"

        try:
            response = requests.post(
                url="https://api.neets.ai/v1/tts",
                headers={
                    "Content-Type": "application/json",
                    "X-API-Key": config.neets_ai_api_key
                },
                json={
                    "text": text,
                    "voice_id": neets_vits_voice_id,
                    "params": {"model": "vits"}
                },
                # Without a timeout a stalled connection blocks the worker
                # handling this request indefinitely.
                timeout=60
            )
        except requests.RequestException as exc:
            # Surface network failures as the ValueError callers handle.
            raise ValueError(f"Neets API request failed: {exc}") from exc

        if response.status_code != 200:
            raise ValueError(f"Neets API error: {response.text}")

        audio_path = f"neets_response_{uuid.uuid4()}.mp3"
        with open(audio_path, "wb") as f:
            f.write(response.content)
        return audio_path

    def clear_conversation(self) -> Tuple[List, str]:
        """Return an empty history and a fresh conversation id."""
        return [], str(uuid.uuid4())
186
+
187
def create_gradio_interface():
    """Build the Gradio UI for the voice/text chat assistant.

    The layout has an API-key box, a chat column (chatbot, text box,
    microphone input, send/clear buttons) and a side column (spoken reply
    and detected-language label). The interface is returned, not launched.

    Returns:
        The configured ``gr.Blocks`` instance.
    """

    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.teal,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.gray,
        text_size=gr.themes.sizes.text_lg,
    ).set(
        button_primary_background_fill="#114A56",
        button_primary_background_fill_hover="#114A56",
        block_title_text_weight="600",
        block_label_text_weight="600",
        block_label_text_size="*text_md",
    )

    conversation_manager = ConversationManager()

    with gr.Blocks(theme=theme, analytics_enabled=False) as demo:
        # Header
        with gr.Row():
            gr.Markdown("""
            # Multilingual Voice Chat Assistant
            Have a natural conversation with Aya using voice or text in any of 23 supported languages.
            """)

        # API Key input
        with gr.Row():
            with gr.Column(scale=1):
                neets_key = gr.Textbox(
                    type="password",
                    label="Enter your Neets.ai API Key",
                    placeholder="Enter API key here...",
                    show_label=True
                )
                api_status = gr.Markdown("API Key Status: Not Set")

        def update_api_key(key):
            """Store the typed key on ``config`` and return the status text."""
            # Always forward the value, so clearing the box also clears a
            # previously stored key instead of leaving it active.
            # NOTE(review): `config` is one module-level object, so the key
            # is shared by every session of this app; confirm that is
            # acceptable.
            config.set_neets_key(key)
            if not key or not key.strip():
                return "API Key Status: Not Set"
            return "API Key Status: Set ✓"

        neets_key.change(
            update_api_key,
            inputs=[neets_key],
            outputs=[api_status]
        )

        # State management
        conversation_id = gr.State("")
        current_language = gr.State("eng_Latn")

        with gr.Row():
            # Chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    show_label=False,
                    show_copy_button=True,
                    height=400,
                    label="Conversation"
                )

                # Input options
                with gr.Row():
                    # Text input
                    text_input = gr.Textbox(
                        placeholder="Type your message or use voice input...",
                        label="Text Input",
                        lines=2
                    )

                    # Voice input. `sources=[...]` is the Gradio 4 spelling;
                    # this file already depends on Gradio 4 through
                    # `queue(default_concurrency_limit=...)`.
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Voice Input"
                    )

                with gr.Row():
                    submit_btn = gr.Button("Send Message", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            # Audio output and info
            with gr.Column(scale=1):
                response_audio = gr.Audio(
                    label="Assistant's Voice Response",
                    type="filepath"
                )

                detected_language = gr.Markdown(
                    "Detected Language: English",
                    label="Language Info"
                )

        # Event handlers
        def process_input(
            input_text: str,
            input_audio: str,
            history: List[Tuple[str, str]],
            conv_id: str,
            neets_key: str
        ):
            """Handle one turn: pick the input, get the reply, speak it.

            Returns the values for, in order: the chatbot, the conversation
            id state, the reply audio, and the detected-language label.

            Raises:
                gr.Error: If no API key or message was given, or speech
                    synthesis fails.
            """
            if not neets_key:
                raise gr.Error("Please enter your Neets.ai API key first")

            # Determine input source (a recording wins over typed text)
            if input_audio:
                user_text, lang_code = conversation_manager.transcribe_audio(input_audio)
            else:
                user_text = input_text
                lang_code = predict_language(user_text)

            if not user_text or not user_text.strip():
                raise gr.Error("Please type or record a message first")

            # Language label. It is returned as an output because calling
            # `.update()` on the component inside a handler changes nothing.
            # NOTE(review): LID_LANGUAGES is not defined in the visible part
            # of this file; it must exist at module level.
            lang_name = LID_LANGUAGES.get(lang_code, "Unknown")
            language_label = f"Detected Language: {lang_name}"

            # generate_response is a generator that yields the growing
            # reply; drain it and keep the last (complete) state.
            new_history, response, new_conv_id = history, "", conv_id
            for new_history, response, new_conv_id in conversation_manager.generate_response(
                user_text, history, conv_id
            ):
                pass

            try:
                # Generate audio response
                audio_path = conversation_manager.text_to_speech(response, lang_code)
            except ValueError as e:
                raise gr.Error(str(e))

            return new_history, new_conv_id, audio_path, language_label

        # Both triggers use the same wiring
        turn_inputs = [
            text_input,
            audio_input,
            chatbot,
            conversation_id,
            neets_key
        ]
        turn_outputs = [
            chatbot,
            conversation_id,
            response_audio,
            detected_language
        ]

        # The send button and Enter in the text box behave identically
        for trigger in (submit_btn.click, text_input.submit):
            trigger(process_input, inputs=turn_inputs, outputs=turn_outputs)
            # Clear inputs after submission
            trigger(lambda: "", None, text_input)
            trigger(lambda: None, None, audio_input)

        # Clear conversation
        clear_btn.click(
            conversation_manager.clear_conversation,
            outputs=[chatbot, conversation_id]
        )

    return demo
362
+
363
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then
    # start the server.
    demo = create_gradio_interface()
    queue_options = dict(api_open=False, max_size=20, default_concurrency_limit=4)
    launch_options = dict(show_api=False, allowed_paths=['/home/user/app'])
    demo.queue(**queue_options).launch(**launch_options)
+ )