from __future__ import annotations

import os
import re
import time
import uuid
import torch
import cohere
import secrets
import fasttext
import requests
import numpy as np
from groq import Groq
from dataclasses import dataclass
from typing import Optional, List, Tuple, Any
from huggingface_hub import hf_hub_download
from functools import lru_cache
import gradio as gr
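
# Note: language identification below relies on facebook/fasttext-language-identification,
# which emits FLORES-200 style codes such as "eng_Latn"; NEETS_AI_LANGID_MAP translates
# those codes into the short identifiers used to build the Neets.ai "vits-<lang>" voice IDs.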
# Mapping of language codes to names and Neets voice IDs
LID_LANGUAGES = {
    "eng_Latn": "English",
    "spa_Latn": "Spanish",
    "fra_Latn": "French",
    # Add other languages as needed
}

NEETS_AI_LANGID_MAP = {
    "eng_Latn": "en",
    "spa_Latn": "es",
    "fra_Latn": "fr",
    # Add other mappings as needed
}

# Chat preamble
CHAT_PREAMBLE = """You are a helpful AI assistant."""

@dataclass
class Config:
    device: str = "cpu"  # Force CPU for ZeroGPU environment
    batch_size: int = 32
    model_name: str = "c4ai-aya-expanse-32b"

    # API keys from environment
    groq_api_key: Optional[str] = os.getenv("GROQ_API_KEY")
    chat_cohere_api_key: Optional[str] = os.getenv("CHAT_COHERE_API_KEY")
    # Neets API key will be set via user input
    neets_ai_api_key: Optional[str] = None

    def set_neets_key(self, key: str):
        """Update the Neets API key."""
        self.neets_ai_api_key = key


config = Config()
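
# get_clients() below reads the Groq and Cohere keys from the environment via the shared
# Config instance (e.g. Space secrets named GROQ_API_KEY and CHAT_COHERE_API_KEY); the
# Neets.ai key is supplied at runtime through the textbox in the Gradio UI instead.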
def get_clients():
    return {
        'chat': cohere.Client(
            api_key=config.chat_cohere_api_key,
            client_name="c4ai-aya-expanse-chat"
        ),
        'groq': Groq(api_key=config.groq_api_key)
    }


clients = get_clients()

@lru_cache(maxsize=1)
def load_lid_model():
    """Load and cache the language identification model."""
    try:
        lid_model_path = hf_hub_download(
            repo_id="facebook/fasttext-language-identification",
            filename="model.bin"
        )
        return fasttext.load_model(lid_model_path)
    except Exception as e:
        print(f"Error loading language model: {e}")
        return None

def predict_language(text: str) -> str:
    """Predict language of input text using FastText model."""
    if not text:
        return "eng_Latn"  # Default to English
    try:
        text = re.sub("\n", " ", text)
        model = load_lid_model()
        if model is None:
            return "eng_Latn"  # Default if model fails to load
        # Use predict method and handle numpy array safely
        prediction = model.predict(text)
        label = prediction[0][0] if prediction and prediction[0] else "__label__eng_Latn"
        return label[len("__label__"):]
    except Exception as e:
        print(f"Language prediction error: {e}")
        return "eng_Latn"  # Default on error

def clean_text(text: str, remove_bullets: bool = False, remove_newline: bool = False) -> str:
    """Clean text by removing formatting and optional elements."""
    if not text:
        return ""
    text = re.sub(r"\*\*", "", text)
    if remove_bullets:
        text = re.sub(r"^- ", "", text, flags=re.MULTILINE)
    if remove_newline:
        text = re.sub(r"\n", " ", text)
    return text.strip()

class ConversationManager:
    """Manages the entire conversation flow including voice, text, and memory."""

    def __init__(self):
        self.chat_client = clients['chat']

    def check_neets_key(self) -> bool:
        """Check if Neets API key is set."""
        return bool(config.neets_ai_api_key)

    def transcribe_audio(self, audio_file: str) -> Tuple[str, str]:
        """Transcribe audio to text and detect language."""
        if not audio_file:
            return "", "eng_Latn"
        try:
            # Transcribe using Whisper
            with open(audio_file, "rb") as f:
                transcription = clients['groq'].audio.transcriptions.create(
                    file=(audio_file, f.read()),
                    model="whisper-large-v3-turbo",
                    response_format="json",
                    temperature=0.0
                )
            text = transcription.text
            lang_code = predict_language(text)
            return text, lang_code
        except Exception as e:
            print(f"Transcription error: {e}")
            return "", "eng_Latn"

    def generate_response(
        self,
        user_input: str,
        chat_history: List[dict],
        conversation_id: Optional[str] = None
    ) -> Tuple[List[dict], str, str]:
        """Generate the assistant's response based on user input and conversation history."""
        if not conversation_id:
            conversation_id = str(uuid.uuid4())
        if not chat_history:
            chat_history = []
        try:
            # Format history for the model: Cohere's chat_history expects
            # role/message dicts with roles "USER" and "CHATBOT"
            formatted_history = []
            for msg in chat_history:
                role = "USER" if msg["role"] == "user" else "CHATBOT"
                formatted_history.append({"role": role, "message": msg["content"]})

            # Generate response
            stream = self.chat_client.chat_stream(
                message=user_input,
                preamble=CHAT_PREAMBLE,
                conversation_id=conversation_id,
                model=config.model_name,
                temperature=0.3,
                chat_history=formatted_history
            )

            # Collect the streamed response
            response = ""
            for event in stream:
                if event.event_type == "text-generation":
                    response += event.text

            # Update chat history (Gradio "messages" format)
            chat_history.extend([
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": response}
            ])
            return chat_history, response, conversation_id
        except Exception as e:
            print(f"Response generation error: {e}")
            error_msg = "I apologize, but I encountered an error generating a response. Please try again."
            chat_history.extend([
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": error_msg}
            ])
            return chat_history, error_msg, conversation_id

    def text_to_speech(self, text: str, lang_code: str) -> str:
        """Convert text to speech using Neets.ai."""
        if not text:
            return None
        if not self.check_neets_key():
            raise ValueError("Neets API key not set. Please enter your API key.")
        try:
            # Get language mapping for Neets.ai
            neets_lang_id = NEETS_AI_LANGID_MAP.get(lang_code, "en")
            neets_vits_voice_id = f"vits-{neets_lang_id}"

            response = requests.post(
                url="https://api.neets.ai/v1/tts",
                headers={
                    "Content-Type": "application/json",
                    "X-API-Key": config.neets_ai_api_key
                },
                json={
                    "text": text,
                    "voice_id": neets_vits_voice_id,
                    "params": {"model": "vits"}
                }
            )
            if response.status_code != 200:
                raise ValueError(f"Neets API error: {response.text}")

            audio_path = f"neets_response_{uuid.uuid4()}.mp3"
            with open(audio_path, "wb") as f:
                f.write(response.content)
            return audio_path
        except Exception as e:
            print(f"Text-to-speech error: {e}")
            raise ValueError(f"Failed to generate speech: {str(e)}")

    def clear_conversation(self) -> Tuple[List, str]:
        """Clear the conversation history."""
        return [], str(uuid.uuid4())

def create_gradio_interface():
    """Create the Gradio interface for the conversational AI system."""
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.teal,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.gray,
        text_size=gr.themes.sizes.text_lg,
    ).set(
        button_primary_background_fill="#114A56",
        button_primary_background_fill_hover="#114A56",
        block_title_text_weight="600",
        block_label_text_weight="600",
        block_label_text_size="*text_md",
    )

    conversation_manager = ConversationManager()

    with gr.Blocks(theme=theme, analytics_enabled=False) as demo:
        # Header
        with gr.Row():
            gr.Markdown("""
            # Multilingual Voice Chat Assistant
            Have a natural conversation with Aya using voice or text in any of 23 supported languages.
            """)

        # API key input
        with gr.Row():
            with gr.Column(scale=1):
                neets_key = gr.Textbox(
                    type="password",
                    label="Enter your Neets.ai API Key",
                    placeholder="Enter API key here...",
                    show_label=True
                )
                api_status = gr.Markdown("API Key Status: Not Set")

        def update_api_key(key):
            if not key:
                return "API Key Status: Not Set"
            config.set_neets_key(key)
            return "API Key Status: Set ✓"

        neets_key.change(
            update_api_key,
            inputs=[neets_key],
            outputs=[api_status]
        )

        # State management
        conversation_id = gr.State("")
        current_language = gr.State("eng_Latn")

        with gr.Row():
            # Chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=400,
                    type="messages",
                )

                # Input options
                with gr.Row():
                    text_input = gr.Textbox(
                        placeholder="Type your message or use voice input...",
                        label="Text Input",
                        lines=2
                    )
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Voice Input"
                    )

                with gr.Row():
                    submit_btn = gr.Button("Send Message", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            # Audio output and info
            with gr.Column(scale=1):
                response_audio = gr.Audio(
                    label="Assistant's Voice Response",
                    type="filepath"
                )
                detected_language = gr.Markdown(
                    "Detected Language: English",
                    label="Language Info"
                )

        def process_input(
            input_text: str,
            input_audio: str,
            history: List[dict],
            conv_id: str,
            neets_key: str
        ):
            if not neets_key:
                raise gr.Error("Please enter your Neets.ai API key first")

            # Determine input source (voice takes precedence over text)
            if input_audio:
                user_text, lang_code = conversation_manager.transcribe_audio(input_audio)
            else:
                user_text = input_text
                lang_code = predict_language(user_text)

            if not user_text:
                raise gr.Error("Please type a message or record some audio first")

            # Get language name
            lang_name = LID_LANGUAGES.get(lang_code, "Unknown")
            language_info = f"Detected Language: {lang_name}"

            # Generate response
            new_history, response, new_conv_id = conversation_manager.generate_response(
                user_text, history, conv_id
            )

            try:
                # Generate audio response
                audio_path = conversation_manager.text_to_speech(response, lang_code)
            except ValueError as e:
                raise gr.Error(str(e))

            return new_history, new_conv_id, audio_path, language_info

        # Connect event handlers
        submit_btn.click(
            process_input,
            inputs=[
                text_input,
                audio_input,
                chatbot,
                conversation_id,
                neets_key
            ],
            outputs=[
                chatbot,
                conversation_id,
                response_audio,
                detected_language
            ]
        )

        # Also trigger on text input enter
        text_input.submit(
            process_input,
            inputs=[
                text_input,
                audio_input,
                chatbot,
                conversation_id,
                neets_key
            ],
            outputs=[
                chatbot,
                conversation_id,
                response_audio,
                detected_language
            ]
        )

        # Clear conversation
        clear_btn.click(
            conversation_manager.clear_conversation,
            outputs=[chatbot, conversation_id]
        )

        # Clear inputs after submission
        submit_btn.click(lambda: "", None, text_input)
        submit_btn.click(lambda: None, None, audio_input)

    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(
        api_open=False,
        max_size=20,
        default_concurrency_limit=4
    ).launch(
        show_api=False,
        allowed_paths=['/home/user/app']
    )