# Aya_For_Schools / app.py
from __future__ import annotations
import os
import re
import time
import uuid
import torch
import cohere
import secrets
import fasttext
import requests
import numpy as np
from groq import Groq
from dataclasses import dataclass
from typing import Optional, List, Tuple, Any
from huggingface_hub import hf_hub_download
from functools import lru_cache
import gradio as gr
# Mapping of language codes to names and Neets voice IDs
LID_LANGUAGES = {
"eng_Latn": "English",
"spa_Latn": "Spanish",
"fra_Latn": "French",
# Add other languages as needed
}
NEETS_AI_LANGID_MAP = {
"eng_Latn": "en",
"spa_Latn": "es",
"fra_Latn": "fr",
# Add other mappings as needed
}
# Chat preamble
CHAT_PREAMBLE = """You are a helpful AI assistant."""
@dataclass
class Config:
device: str = "cpu" # Force CPU for ZeroGPU environment
batch_size: int = 32
model_name: str = "aya-expanse-32B"
# API Keys from environment
    groq_api_key: Optional[str] = os.getenv("GROQ_API_KEY")
    chat_cohere_api_key: Optional[str] = os.getenv("CHAT_COHERE_API_KEY")
# Neets API key will be set via user input
neets_ai_api_key: Optional[str] = None
def set_neets_key(self, key: str):
"""Update Neets API key."""
self.neets_ai_api_key = key
config = Config()
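# Single shared Config instance; the Neets key is filled in at runtime from the UI below.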
def get_clients():
return {
'chat': cohere.Client(
api_key=config.chat_cohere_api_key,
client_name="c4ai-aya-expanse-chat"
),
'groq': Groq(api_key=config.groq_api_key)
}
clients = get_clients()
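# Note: both clients are built once at import time, so GROQ_API_KEY and
# CHAT_COHERE_API_KEY must already be set in the environment when the app starts.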
@lru_cache(maxsize=1)
def load_lid_model():
"""Load and cache the language identification model."""
try:
lid_model_path = hf_hub_download(
repo_id="facebook/fasttext-language-identification",
filename="model.bin"
)
return fasttext.load_model(lid_model_path)
except Exception as e:
print(f"Error loading language model: {e}")
return None
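# With lru_cache(maxsize=1), the fastText model is downloaded and loaded only on the
# first call; every later call to load_lid_model() reuses the cached instance.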
def predict_language(text: str) -> str:
"""Predict language of input text using FastText model."""
if not text:
return "eng_Latn" # Default to English
try:
text = re.sub("\n", " ", text)
model = load_lid_model()
if model is None:
return "eng_Latn" # Default if model fails to load
# Use predict method and handle numpy array safely
prediction = model.predict(text)
label = prediction[0][0] if prediction and prediction[0] else "__label__eng_Latn"
return label[len("__label__"):]
except Exception as e:
print(f"Language prediction error: {e}")
return "eng_Latn" # Default on error
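# Rough shape of the fastText output parsed above (probability is illustrative):
#   model.predict("Bonjour tout le monde")
#   -> (('__label__fra_Latn',), array([0.99]))
# The "__label__" prefix is stripped, leaving codes such as "fra_Latn".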
def clean_text(text: str, remove_bullets: bool = False, remove_newline: bool = False) -> str:
"""Clean text by removing formatting and optional elements."""
if not text:
return ""
text = re.sub(r"\*\*", "", text)
if remove_bullets:
text = re.sub(r"^- ", "", text, flags=re.MULTILINE)
if remove_newline:
text = re.sub(r"\n", " ", text)
return text.strip()
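# Illustrative example:
#   clean_text("**Hello**\n- first point", remove_bullets=True)
#   -> "Hello\nfirst point"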
class ConversationManager:
"""Manages the entire conversation flow including voice, text, and memory."""
def __init__(self):
self.chat_client = clients['chat']
def check_neets_key(self) -> bool:
"""Check if Neets API key is set."""
return bool(config.neets_ai_api_key)
def transcribe_audio(self, audio_file: str) -> Tuple[str, str]:
"""Transcribe audio to text and detect language."""
if not audio_file:
return "", "eng_Latn"
try:
# Transcribe using Whisper
with open(audio_file, "rb") as f:
transcription = clients['groq'].audio.transcriptions.create(
file=(audio_file, f.read()),
model="whisper-large-v3-turbo",
response_format="json",
temperature=0.0
)
text = transcription.text
lang_code = predict_language(text)
return text, lang_code
except Exception as e:
print(f"Transcription error: {e}")
return "", "eng_Latn"
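    # Note: language detection runs on the transcribed text, not on the audio itself;
    # an empty transcription falls back to English ("eng_Latn").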
def generate_response(
self,
user_input: str,
chat_history: List[dict],
        conversation_id: Optional[str] = None
) -> Tuple[List[dict], str, str]:
"""Generate assistant's response based on user input and conversation history."""
if not conversation_id:
conversation_id = str(uuid.uuid4())
if not chat_history:
chat_history = []
try:
            # Reformat Gradio-style history ({"role", "content"}) into the
            # {"role": "USER"/"CHATBOT", "message": ...} entries Cohere's chat API expects
            formatted_history = []
            for msg in chat_history:
                role = "USER" if msg["role"] == "user" else "CHATBOT"
                formatted_history.append({"role": role, "message": msg["content"]})
# Generate response
stream = self.chat_client.chat_stream(
message=user_input,
preamble=CHAT_PREAMBLE,
conversation_id=conversation_id,
model=config.model_name,
temperature=0.3,
chat_history=formatted_history
)
# Collect response
response = ""
for event in stream:
if event.event_type == "text-generation":
response += event.text
# Update chat history
chat_history.extend([
{"role": "user", "content": user_input},
{"role": "assistant", "content": response}
])
            return chat_history, response, conversation_id
except Exception as e:
print(f"Response generation error: {e}")
error_msg = "I apologize, but I encountered an error generating a response. Please try again."
chat_history.extend([
{"role": "user", "content": user_input},
{"role": "assistant", "content": error_msg}
])
return chat_history, error_msg, conversation_id
    def text_to_speech(self, text: str, lang_code: str) -> Optional[str]:
"""Convert text to speech using Neets.ai."""
if not text:
return None
if not self.check_neets_key():
raise ValueError("Neets API key not set. Please enter your API key.")
try:
# Get language mapping for Neets.ai
neets_lang_id = NEETS_AI_LANGID_MAP.get(lang_code, "en")
neets_vits_voice_id = f"vits-{neets_lang_id}"
response = requests.post(
url="https://api.neets.ai/v1/tts",
headers={
"Content-Type": "application/json",
"X-API-Key": config.neets_ai_api_key
},
json={
"text": text,
"voice_id": neets_vits_voice_id,
"params": {"model": "vits"}
}
)
if response.status_code != 200:
raise ValueError(f"Neets API error: {response.text}")
audio_path = f"neets_response_{uuid.uuid4()}.mp3"
with open(audio_path, "wb") as f:
f.write(response.content)
return audio_path
except Exception as e:
print(f"Text-to-speech error: {e}")
raise ValueError(f"Failed to generate speech: {str(e)}")
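    # Note: each successful call writes a new "neets_response_<uuid>.mp3" into the
    # working directory and nothing removes it afterwards; long-running deployments
    # may want to write to a temporary directory instead.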
def clear_conversation(self) -> Tuple[List, str]:
"""Clear the conversation history."""
return [], str(uuid.uuid4())
def create_gradio_interface():
"""Create the Gradio interface for the conversational AI system."""
theme = gr.themes.Base(
primary_hue=gr.themes.colors.teal,
secondary_hue=gr.themes.colors.blue,
neutral_hue=gr.themes.colors.gray,
text_size=gr.themes.sizes.text_lg,
).set(
button_primary_background_fill="#114A56",
button_primary_background_fill_hover="#114A56",
block_title_text_weight="600",
block_label_text_weight="600",
block_label_text_size="*text_md",
)
conversation_manager = ConversationManager()
with gr.Blocks(theme=theme, analytics_enabled=False) as demo:
# Header
with gr.Row():
gr.Markdown("""
# Multilingual Voice Chat Assistant
Have a natural conversation with Aya using voice or text in any of 23 supported languages.
""")
# API Key input
with gr.Row():
with gr.Column(scale=1):
neets_key = gr.Textbox(
type="password",
label="Enter your Neets.ai API Key",
placeholder="Enter API key here...",
show_label=True
)
api_status = gr.Markdown("API Key Status: Not Set")
def update_api_key(key):
if not key:
return "API Key Status: Not Set"
config.set_neets_key(key)
return "API Key Status: Set ✓"
neets_key.change(
update_api_key,
inputs=[neets_key],
outputs=[api_status]
)
# State management
conversation_id = gr.State("")
current_language = gr.State("eng_Latn")
with gr.Row():
# Chat interface
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Conversation",
height=400,
type="messages",
)
# Input options
with gr.Row():
text_input = gr.Textbox(
placeholder="Type your message or use voice input...",
label="Text Input",
lines=2
)
audio_input = gr.Audio(
sources=["microphone"],
type="filepath",
label="Voice Input"
)
with gr.Row():
submit_btn = gr.Button("Send Message", variant="primary")
clear_btn = gr.Button("Clear Conversation")
# Audio output and info
with gr.Column(scale=1):
response_audio = gr.Audio(
label="Assistant's Voice Response",
type="filepath"
)
detected_language = gr.Markdown(
"Detected Language: English",
label="Language Info"
)
def process_input(
input_text: str,
input_audio: str,
history: List[dict],
conv_id: str,
neets_key: str
):
if not neets_key:
raise gr.Error("Please enter your Neets.ai API key first")
            # Determine input source: voice input takes precedence when both are provided
            if input_audio:
                user_text, lang_code = conversation_manager.transcribe_audio(input_audio)
            else:
                user_text = input_text
                lang_code = predict_language(user_text)
            if not user_text or not user_text.strip():
                raise gr.Error("Please provide a message via text or voice input")
# Get language name
lang_name = LID_LANGUAGES.get(lang_code, "Unknown")
language_info = f"Detected Language: {lang_name}"
# Generate response
new_history, response, new_conv_id = conversation_manager.generate_response(
user_text, history, conv_id
)
try:
# Generate audio response
audio_path = conversation_manager.text_to_speech(response, lang_code)
except ValueError as e:
raise gr.Error(str(e))
return new_history, new_conv_id, audio_path, language_info
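        # The four values returned above map positionally onto the outputs lists
        # wired to submit_btn.click and text_input.submit below.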
# Connect event handlers
submit_btn.click(
process_input,
inputs=[
text_input,
audio_input,
chatbot,
conversation_id,
neets_key
],
outputs=[
chatbot,
conversation_id,
response_audio,
detected_language
]
)
# Also trigger on text input enter
text_input.submit(
process_input,
inputs=[
text_input,
audio_input,
chatbot,
conversation_id,
neets_key
],
outputs=[
chatbot,
conversation_id,
response_audio,
detected_language
]
)
# Clear conversation
clear_btn.click(
conversation_manager.clear_conversation,
outputs=[chatbot, conversation_id]
)
        # Clear inputs after submission (button click and Enter submit)
        submit_btn.click(lambda: "", None, text_input)
        submit_btn.click(lambda: None, None, audio_input)
        text_input.submit(lambda: "", None, text_input)
        text_input.submit(lambda: None, None, audio_input)
return demo
if __name__ == "__main__":
demo = create_gradio_interface()
demo.queue(
api_open=False,
max_size=20,
default_concurrency_limit=4
).launch(
show_api=False,
allowed_paths=['/home/user/app']
)