Spaces:
Sleeping
Sleeping
jeremierostan
committed on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
import uuid
|
6 |
+
import torch
|
7 |
+
import cohere
|
8 |
+
import secrets
|
9 |
+
import fasttext
|
10 |
+
import requests
|
11 |
+
from groq import Groq
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from typing import Optional, List, Tuple, Any
|
14 |
+
from huggingface_hub import hf_hub_download
|
15 |
+
from functools import lru_cache
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
# Configuration
|
19 |
+
@dataclass
class Config:
    """Runtime settings shared by the whole app.

    The defaults below are evaluated once, when the class body runs, so the
    environment variables must already be set when this module is imported.
    """

    # Torch device string: "cuda" when a GPU is visible, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 32
    # Name of the chat model.
    # NOTE(review): confirm this identifier matches what the chat API expects.
    model_name: str = "aya-expanse-32B"

    # API keys from the environment; os.getenv returns None when unset.
    groq_api_key: Optional[str] = os.getenv("GROQ_API_KEY")
    chat_cohere_api_key: Optional[str] = os.getenv("CHAT_COHERE_API_KEY")

    # Neets API key is supplied at runtime via user input.
    neets_ai_api_key: Optional[str] = None

    def set_neets_key(self, key: Optional[str]) -> None:
        """Store the Neets API key.

        Surrounding whitespace is removed because pasted keys often carry a
        trailing space or newline. Blank or missing input clears the stored
        key (sets it to None).

        Args:
            key: The key as typed or pasted by the user, or None/"" to clear.
        """
        cleaned = key.strip() if key else ""
        self.neets_ai_api_key = cleaned or None


config = Config()
|
37 |
+
|
38 |
+
# Initialize clients
|
39 |
+
def get_clients():
    """Build the remote API clients, keyed by the role they play.

    Returns:
        A dict with a Cohere client under ``'chat'`` and a Groq client under
        ``'groq'``, both configured from the module-level ``config``.
    """
    cohere_chat = cohere.Client(
        api_key=config.chat_cohere_api_key,
        client_name="c4ai-aya-expanse-chat",
    )
    groq_client = Groq(api_key=config.groq_api_key)

    registry = {}
    registry['chat'] = cohere_chat
    registry['groq'] = groq_client
    return registry


# Created once at import time and shared by everything in this module.
clients = get_clients()
|
49 |
+
|
50 |
+
# Language identification
|
51 |
+
@lru_cache(maxsize=1)
def load_lid_model():
    """Fetch the fastText language-identification model and load it.

    The result is memoised by ``lru_cache``, so the download and load happen
    on the first call only; later calls return the same model object.
    """
    model_file = hf_hub_download(
        repo_id="facebook/fasttext-language-identification",
        filename="model.bin",
    )
    lid_model = fasttext.load_model(model_file)
    return lid_model
|
59 |
+
|
60 |
+
def predict_language(text: str) -> str:
    """Predict the language of *text* with the fastText LID model.

    Args:
        text: Arbitrary user text; may be empty, None, or span several lines.

    Returns:
        The predicted label with its ``__label__`` prefix removed, or
        ``"eng_Latn"`` when there is nothing to classify.
    """
    default_lang = "eng_Latn"  # Default to English

    # Collapse every whitespace run (newlines, tabs, CRs) into one space. The
    # model is given a single line, and whitespace-only input becomes "" so it
    # takes the default path instead of reaching the model.
    normalized = " ".join(text.split()) if text else ""
    if not normalized:
        return default_lang

    model = load_lid_model()
    labels, _ = model.predict(normalized)
    if not labels:
        # No prediction came back; fall back instead of raising IndexError.
        return default_lang

    prefix = "__label__"
    top_label = labels[0]
    if top_label.startswith(prefix):
        return top_label[len(prefix):]
    return top_label
|
68 |
+
|
69 |
+
def clean_text(text: str, remove_bullets: bool = False, remove_newline: bool = False) -> str:
    """Strip lightweight Markdown formatting from *text*.

    Args:
        text: Text to clean; empty or None yields "".
        remove_bullets: When true, drop a leading "- " list marker from every
            line, including indented (nested) list items.
        remove_newline: When true, replace each line break (LF, CRLF or a
            lone CR) with a single space.

    Returns:
        The cleaned text with leading and trailing whitespace removed.
    """
    if not text:
        return ""

    # Bold markers are a fixed literal, so no regex is needed.
    text = text.replace("**", "")
    if remove_bullets:
        # Allow leading spaces/tabs so nested bullets are stripped as well.
        text = re.sub(r"^[ \t]*- ", "", text, flags=re.MULTILINE)
    if remove_newline:
        # Treat CRLF as one break so no stray "\r" survives.
        text = re.sub(r"\r\n|\r|\n", " ", text)

    return text.strip()
|
81 |
+
|
82 |
+
class ConversationManager:
    """Manages the entire conversation flow including voice, text, and memory.

    Speech-to-text goes through the Groq client, chat through the Cohere
    client, and text-to-speech through the Neets.ai HTTP API.
    """

    def __init__(self):
        # Reuse the module-level Cohere client instead of building a new one.
        self.chat_client = clients['chat']

    def check_neets_key(self) -> bool:
        """Return True when a Neets API key is stored in ``config``."""
        return bool(config.neets_ai_api_key)

    def transcribe_audio(self, audio_file: str) -> Tuple[str, str]:
        """Transcribe audio to text and detect its language.

        Args:
            audio_file: Path of the recording; empty/None means "no audio".

        Returns:
            ``(text, lang_code)``. With no audio this is ``("", "eng_Latn")``.
        """
        if not audio_file:
            return "", "eng_Latn"

        # Transcribe using Whisper
        with open(audio_file, "rb") as f:
            transcription = clients['groq'].audio.transcriptions.create(
                file=(audio_file, f.read()),
                model="whisper-large-v3-turbo",
                response_format="json",
                temperature=0.0,
            )

        text = transcription.text
        lang_code = predict_language(text)

        return text, lang_code

    def generate_response(
        self,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        conversation_id: Optional[str] = None
    ) -> Tuple[List[Tuple[str, str]], str, str]:
        """Generate the assistant's reply to *user_input*.

        This is a plain function that returns one tuple (not a generator),
        as its annotation and the three-value unpack at its call site expect.

        Args:
            user_input: The user's message for this turn.
            chat_history: Earlier ``(user, assistant)`` turns; not mutated.
            conversation_id: Existing conversation id; a fresh UUID4 string is
                generated when it is empty or None.

        Returns:
            ``(history, response, conversation_id)`` where ``history`` is a
            new list: the earlier turns plus ``(user_input, response)``.
        """
        if not conversation_id:
            conversation_id = str(uuid.uuid4())

        # Work on a copy so the caller's list is never modified in place.
        prior_turns = list(chat_history or [])

        # Format history for the model as role-tagged messages.
        formatted_history = []
        for human, assistant in prior_turns:
            if human:
                formatted_history.append({"role": "USER", "message": human})
            if assistant:
                formatted_history.append({"role": "CHATBOT", "message": assistant})

        # NOTE(review): CHAT_PREAMBLE is not defined in the visible part of
        # this file; it must exist at module level before this runs.
        # NOTE(review): both conversation_id and chat_history are sent, as in
        # the original; confirm the API does not double-count the history.
        stream = self.chat_client.chat_stream(
            message=user_input,
            preamble=CHAT_PREAMBLE,
            conversation_id=conversation_id,
            model=config.model_name,
            temperature=0.3,
            chat_history=formatted_history,
        )

        # Collect the streamed text chunks into the full response.
        response = "".join(
            event.text for event in stream
            if event.event_type == "text-generation"
        )

        new_history = prior_turns + [(user_input, response)]
        return new_history, response, conversation_id

    def text_to_speech(self, text: str, lang_code: str) -> Optional[str]:
        """Convert text to speech using Neets.ai.

        Args:
            text: Text to speak; empty/None short-circuits to None.
            lang_code: Language code used to pick the voice; unknown codes
                fall back to the "en" voice.

        Returns:
            Path of the MP3 file written to the current working directory,
            or None when there is no text.

        Raises:
            ValueError: When no API key is set, the request fails at the
                network level, or the API answers with a non-200 status.
        """
        if not text:
            return None

        if not self.check_neets_key():
            raise ValueError("Neets API key not set. Please enter your API key.")

        # Get language mapping for Neets.ai
        # NOTE(review): NEETS_AI_LANGID_MAP is not defined in the visible part
        # of this file; it must exist at module level before this runs.
        neets_lang_id = NEETS_AI_LANGID_MAP.get(lang_code, "en")
        neets_vits_voice_id = f"vits-{neets_lang_id}"

        try:
            response = requests.post(
                url="https://api.neets.ai/v1/tts",
                headers={
                    "Content-Type": "application/json",
                    "X-API-Key": config.neets_ai_api_key
                },
                json={
                    "text": text,
                    "voice_id": neets_vits_voice_id,
                    "params": {"model": "vits"}
                },
                # Without a timeout a stalled connection would block forever.
                timeout=60,
            )
        except requests.RequestException as err:
            # Surface network failures as ValueError, the error type this
            # method already uses for API problems.
            raise ValueError(f"Neets API error: {err}") from err

        if response.status_code != 200:
            raise ValueError(f"Neets API error: {response.text}")

        audio_path = f"neets_response_{uuid.uuid4()}.mp3"
        with open(audio_path, "wb") as f:
            f.write(response.content)
        return audio_path

    def clear_conversation(self) -> Tuple[List, str]:
        """Return an empty history and a fresh conversation id."""
        return [], str(uuid.uuid4())
|
186 |
+
|
187 |
+
def create_gradio_interface():
    """Create the Gradio interface for the conversational AI system.

    Builds the theme, the layout (API-key row, chat column, audio/info
    column) and wires the event handlers.

    Returns:
        The ``gr.Blocks`` app, not yet queued or launched.
    """

    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.teal,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.gray,
        text_size=gr.themes.sizes.text_lg,
    ).set(
        button_primary_background_fill="#114A56",
        button_primary_background_fill_hover="#114A56",
        block_title_text_weight="600",
        block_label_text_weight="600",
        block_label_text_size="*text_md",
    )

    conversation_manager = ConversationManager()

    with gr.Blocks(theme=theme, analytics_enabled=False) as demo:
        # Header
        with gr.Row():
            gr.Markdown("""
            # Multilingual Voice Chat Assistant
            Have a natural conversation with Aya using voice or text in any of 23 supported languages.
            """)

        # API Key input
        with gr.Row():
            with gr.Column(scale=1):
                neets_key = gr.Textbox(
                    type="password",
                    label="Enter your Neets.ai API Key",
                    placeholder="Enter API key here...",
                    show_label=True
                )
                api_status = gr.Markdown("API Key Status: Not Set")

        def update_api_key(key):
            """Mirror the key textbox into ``config`` and report its status."""
            if not key:
                # Clearing the box also forgets the previously stored key.
                config.neets_ai_api_key = None
                return "API Key Status: Not Set"
            config.set_neets_key(key)
            return "API Key Status: Set ✓"

        neets_key.change(
            update_api_key,
            inputs=[neets_key],
            outputs=[api_status]
        )

        # State management
        conversation_id = gr.State("")
        current_language = gr.State("eng_Latn")

        with gr.Row():
            # Chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    show_label=False,
                    show_copy_button=True,
                    height=400,
                    label="Conversation"
                )

                # Input options
                with gr.Row():
                    # Text input
                    text_input = gr.Textbox(
                        placeholder="Type your message or use voice input...",
                        label="Text Input",
                        lines=2
                    )

                    # Voice input. `sources` is the list-valued spelling that
                    # goes with the Gradio 4 queue options used at launch.
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Voice Input"
                    )

                with gr.Row():
                    submit_btn = gr.Button("Send Message", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            # Audio output and info
            with gr.Column(scale=1):
                response_audio = gr.Audio(
                    label="Assistant's Voice Response",
                    type="filepath"
                )

                detected_language = gr.Markdown(
                    "Detected Language: English",
                    label="Language Info"
                )

        # Event handlers
        def process_input(
            input_text: str,
            input_audio: str,
            history: List[Tuple[str, str]],
            conv_id: str,
            neets_key_value: str
        ):
            """Handle one user turn: transcribe, reply, then synthesise speech.

            Returns the values for, in order: chatbot history, conversation
            id, response audio path, detected-language text, language state.
            """
            if not neets_key_value:
                raise gr.Error("Please enter your Neets.ai API key first")

            # Make sure text-to-speech uses the key sent with this request,
            # even if the textbox change event has not been processed yet.
            # NOTE(review): `config` is module-global, so the key is shared by
            # all sessions of this process; consider per-session storage.
            config.set_neets_key(neets_key_value)

            # Determine input source
            if input_audio:
                user_text, lang_code = conversation_manager.transcribe_audio(input_audio)
            else:
                user_text = input_text
                lang_code = predict_language(user_text)

            if not user_text or not user_text.strip():
                # Nothing to send; avoid calling the chat model with no text.
                raise gr.Error("Please type a message or record audio first")

            # Update language display. The text is returned as an output
            # because calling .update() on the component here has no effect.
            # NOTE(review): LID_LANGUAGES is not defined in the visible part
            # of this file; it must exist at module level before this runs.
            lang_name = LID_LANGUAGES.get(lang_code, "Unknown")
            language_text = f"Detected Language: {lang_name}"

            # Generate response
            new_history, response, new_conv_id = conversation_manager.generate_response(
                user_text, history, conv_id
            )

            try:
                # Generate audio response
                audio_path = conversation_manager.text_to_speech(response, lang_code)
            except ValueError as e:
                raise gr.Error(str(e)) from e

            return new_history, new_conv_id, audio_path, language_text, lang_code

        # Connect event handlers: the button and Enter share the same wiring.
        turn_inputs = [
            text_input,
            audio_input,
            chatbot,
            conversation_id,
            neets_key
        ]
        turn_outputs = [
            chatbot,
            conversation_id,
            response_audio,
            detected_language,
            current_language
        ]

        submit_btn.click(process_input, inputs=turn_inputs, outputs=turn_outputs)

        # Also trigger on text input enter
        text_input.submit(process_input, inputs=turn_inputs, outputs=turn_outputs)

        # Clear conversation
        clear_btn.click(
            conversation_manager.clear_conversation,
            outputs=[chatbot, conversation_id]
        )

        # Clear inputs after submission (for both ways of submitting).
        for trigger in (submit_btn.click, text_input.submit):
            trigger(lambda: "", None, text_input)
            trigger(lambda: None, None, audio_input)

    return demo
|
362 |
+
|
363 |
+
if __name__ == "__main__":
    # Build the UI, enable the request queue, then start the server.
    demo = create_gradio_interface()

    queue_options = {
        "api_open": False,
        "max_size": 20,
        "default_concurrency_limit": 4,
    }
    launch_options = {
        "show_api": False,
        "allowed_paths": ['/home/user/app'],
    }

    demo.queue(**queue_options).launch(**launch_options)
|