import gradio as gr
import time
import torch
import os
import gc
import psutil
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, VitsModel, VitsTokenizer
import soundfile as sf
import librosa
import tempfile
import google.generativeai as genai
from dotenv import load_dotenv

# Try to load .env file as fallback (for local development)
# HF Spaces will use secrets directly, so this won't override them
load_dotenv()

# Set environment variables for optimization
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"  # Use tmp for HF Spaces
os.environ["HF_HOME"] = "/tmp/huggingface"  # Cache location

def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

def log_memory(context=""):
    """Log current memory usage"""
    memory_mb = get_memory_usage()
    print(f"Memory usage {context}: {memory_mb:.1f} MB")

class LatinConversationBot:
    def __init__(self):
        log_memory("at initialization start")
        # Force CPU-only to reduce memory usage on Hugging Face Spaces
        self.device = "cpu"
        self.message_audio = {}
        self.message_texts = {}

        # Initialize Gemini using HF Spaces secret or .env fallback
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            # More helpful error message for both HF Spaces and local dev
            raise ValueError(
                "GEMINI_API_KEY not found!\n"
                "For Hugging Face Spaces:\n"
                "  1. Go to your Space settings\n"
                "  2. Click on 'Repository secrets'\n"
                "  3. Add 'GEMINI_API_KEY' with your API key\n"
                "For Local Development:\n"
                "  1. Create a .env file in the project root\n"
                "  2. Add: GEMINI_API_KEY=your_api_key_here"
            )
        genai.configure(api_key=api_key)
        self.gemini_model = genai.GenerativeModel('gemini-flash-latest')

        # Model containers
        self.asr_processor = None
        self.asr_model = None
        self.tts_model = None
        self.tts_tokenizer = None
        self.models_loaded = {"asr": False, "tts": False}

        print(f"Bot initialized on device: {self.device}")
        # Pre-load models at startup for faster response
        try:
            print("🚀 Starting model pre-loading...")
            self._preload_models()
            print("✅ All models loaded successfully!")
        except Exception as e:
            print(f"⚠️ Model pre-loading failed: {e}")
            print("Models will be loaded on-demand")
        log_memory("after initialization")
    def _preload_models(self):
        """Pre-load models at startup but manage memory efficiently"""
        try:
            # Load ASR first with optimizations
            print("📥 Loading ASR models...")
            self.asr_processor = AutoProcessor.from_pretrained(
                "ken-z/latin_whisper-small",
                cache_dir="/tmp/transformers_cache",
                local_files_only=False
            )
            self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "ken-z/latin_whisper-small",
                torch_dtype=torch.float32,
                cache_dir="/tmp/transformers_cache",
                low_cpu_mem_usage=True,  # Optimize memory usage
                local_files_only=False
            ).to(self.device)
            self.models_loaded["asr"] = True
            log_memory("after ASR loading")

            # Load TTS with optimizations
            print("🎵 Loading TTS models...")
            self.tts_tokenizer = VitsTokenizer.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                cache_dir="/tmp/transformers_cache",
                local_files_only=False
            )
            self.tts_model = VitsModel.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                torch_dtype=torch.float32,
                cache_dir="/tmp/transformers_cache",
                low_cpu_mem_usage=True,  # Optimize memory usage
                local_files_only=False
            ).to(self.device)
            self.models_loaded["tts"] = True
            log_memory("after TTS loading")
        except Exception as e:
            print(f"Error in model loading: {e}")
            # Fallback to lazy loading
            self.models_loaded = {"asr": False, "tts": False}
            raise e

    def _ensure_asr_loaded(self):
        """Ensure ASR models are loaded"""
        if not self.models_loaded["asr"]:
            print("Loading ASR models on-demand...")
            self.asr_processor = AutoProcessor.from_pretrained("ken-z/latin_whisper-small")
            self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                "ken-z/latin_whisper-small",
                torch_dtype=torch.float32
            ).to(self.device)
            self.models_loaded["asr"] = True

    def _ensure_tts_loaded(self):
        """Ensure TTS models are loaded"""
        if not self.models_loaded["tts"]:
            print("Loading TTS models on-demand...")
            self.tts_tokenizer = VitsTokenizer.from_pretrained("Ken-Z/latin_SpeechT5")
            self.tts_model = VitsModel.from_pretrained(
                "Ken-Z/latin_SpeechT5",
                torch_dtype=torch.float32
            ).to(self.device)
            self.models_loaded["tts"] = True

    def _cleanup_models(self):
        """Free up memory by clearing unused models"""
        log_memory("before cleanup")
        if self.asr_model is not None:
            del self.asr_model
            self.asr_model = None
            self.models_loaded["asr"] = False
        if self.asr_processor is not None:
            del self.asr_processor
            self.asr_processor = None
        if self.tts_model is not None:
            del self.tts_model
            self.tts_model = None
            self.models_loaded["tts"] = False
        if self.tts_tokenizer is not None:
            del self.tts_tokenizer
            self.tts_tokenizer = None
        gc.collect()
        log_memory("after cleanup")
        print("Models cleaned up from memory")

    def transcribe_audio(self, audio_path):
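        """Transcribe an audio file to Latin text with the fine-tuned Whisper model (resampled to 16 kHz)."""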
        try:
            # Ensure ASR models are loaded
            self._ensure_asr_loaded()
            audio, _ = librosa.load(audio_path, sr=16000)
            input_features = self.asr_processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(self.device)
            with torch.no_grad():
                predicted_ids = self.asr_model.generate(input_features)
            result = self.asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
            # Clean up tensors but keep models loaded
            del input_features, predicted_ids
            gc.collect()
            return result
        except Exception as e:
            print(f"ASR Error: {str(e)}")
            return f"Error: {str(e)}"

    def _call_gemini(self, prompt):
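        """Send a prompt to Gemini; returns the stripped response text, or an error string on failure."""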
        try:
            return self.gemini_model.generate_content(prompt).text.strip()
        except Exception as e:
            print(f"Gemini API error: {e}")
            return "Error: Gemini API not available"

    def generate_response(self, text):
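        """Generate a short, conversational reply in Classical Latin via Gemini."""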
        prompt = f"""You are a Latin conversation bot. Respond ONLY in Latin, keep responses to 1-2 sentences, use proper Classical Latin grammar with proper diacritics, and be conversational.
Examples: "Salve" → "Salve! Quid agis hodie?", "Hello" → "Salve! Latine loquere, quaeso!"
User: {text}
Response:"""
        return self._call_gemini(prompt)

    def improve_latin_grammar(self, text):
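        """Ask Gemini to correct Latin grammar; returns a dict with 'corrected' text and an 'explanation'."""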
        prompt = f"""Fix Latin grammar, diacritics, and word order. Format:
CORRECTED: [corrected text]
EXPLANATION: [brief explanation of fixes only]
Text: {text}"""
        response = self._call_gemini(prompt)

        # Parse response
        corrected = explanation = ""
        for line in response.split('\n'):
            if line.startswith("CORRECTED:"):
                corrected = line[10:].strip()
            elif line.startswith("EXPLANATION:"):
                explanation = line[12:].strip()
        return {
            "corrected": corrected or text,
            "explanation": explanation or "No explanation provided."
        }

    def translate_latin(self, text, target_language):
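        """Translate Latin text into the chosen target language via Gemini."""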
        prompt = f"""Translate this Latin text to {target_language}. Return ONLY the translation, no explanations.
Latin text: {text}
{target_language} translation:"""
        return self._call_gemini(prompt)

    def synthesize_speech(self, text):
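        """Synthesize Latin speech with the VITS TTS model; returns a temp WAV path, or None on error."""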
        try:
            # Ensure TTS models are loaded
            self._ensure_tts_loaded()
            inputs = self.tts_tokenizer(text, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                speech = self.tts_model(**inputs).waveform.squeeze().cpu().numpy()
            # Clean up tensors but keep models loaded
            del inputs
            gc.collect()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, speech, samplerate=16000)
                return tmp_file.name
        except Exception as e:
            print(f"TTS error: {e}")
            return None

def add_message(history, message):
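    """Append user input (audio transcriptions and/or text) to the chat history."""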
    for file_info in message["files"]:
        file_path = file_info.path if hasattr(file_info, 'path') else file_info
        if file_path.endswith(('.wav', '.mp3', '.m4a', '.ogg', '.flac')):
            transcription = bot_instance.transcribe_audio(file_path)
            history.append({"role": "user", "content": f"🎤 {transcription}"})
    if message["text"] and message["text"].strip():
        history.append({"role": "user", "content": message["text"]})
    return history, gr.MultimodalTextbox(value=None, interactive=False)

def get_dropdown_choices(history):
    """Generate all dropdown choices at once"""
    replay_choices = [(f"🔁 {text[:30]}{'...' if len(text) > 30 else ''}", msg_id)
                      for msg_id, text in bot_instance.message_texts.items()]
    improve_choices = [(f"Message {i+1}: {msg['content'].replace('🎤 ', '')[:50]}{'...' if len(msg['content'].replace('🎤 ', '')) > 50 else ''}", i)
                       for i, msg in enumerate(history) if msg["role"] == "user"]
    translate_choices = [(f"Bot {i+1}: {msg['content'][:50]}{'...' if len(msg['content']) > 50 else ''}", i)
                         for i, msg in enumerate(history) if msg["role"] == "assistant"]
    return replay_choices, improve_choices, translate_choices

def bot(history):
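    """Generate the bot's Latin reply, synthesize its audio, and refresh all dropdown choices."""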
    if not history:
        return history, None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
    last_message = history[-1]["content"]
    user_text = last_message.replace("🎤 ", "") if last_message.startswith("🎤 ") else last_message
    response_text = bot_instance.generate_response(user_text)
    message_id = f"msg_{len(history)}_{int(time.time())}"
    history.append({"role": "assistant", "content": response_text})
    audio_file = bot_instance.synthesize_speech(response_text)
    if audio_file:
        bot_instance.message_audio[message_id] = audio_file
        bot_instance.message_texts[message_id] = response_text
    replay_choices, improve_choices, translate_choices = get_dropdown_choices(history)
    return history, audio_file, gr.Dropdown(choices=replay_choices), gr.Dropdown(choices=improve_choices), gr.Dropdown(choices=translate_choices)

def improve_message_grammar(history, message_index):
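    """Replace a user message with its grammar-corrected version; returns the history and the explanation."""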
    if not history or message_index < 0 or message_index >= len(history) or history[message_index]["role"] != "user":
        return history, ""
    original_text = history[message_index]["content"]
    prefix = "🎤 " if original_text.startswith("🎤 ") else ""
    text_to_improve = original_text.replace("🎤 ", "")
    improvement_result = bot_instance.improve_latin_grammar(text_to_improve)
    corrected_text = improvement_result["corrected"]
    explanation = improvement_result["explanation"]
    if corrected_text and corrected_text != text_to_improve:
        history[message_index]["content"] = f"{prefix}{corrected_text} ✨"
    return history, explanation

def clear_all_data():
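    """Clear chat state and unload models to free memory."""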
    bot_instance.message_audio.clear()
    bot_instance.message_texts.clear()
    # Also clean up models to free memory
    bot_instance._cleanup_models()
    print("All data and models cleared from memory")
    return [], None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

# Initialize the bot instance early
print("🚀 Initializing Latin Conversation Bot...")
bot_instance = LatinConversationBot()

with gr.Blocks(title="🏛️ Latin Conversation Bot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏛️ Latin Conversation Bot
    Speak or type in Latin for AI-powered conversations with speech synthesis and grammar improvement!
    """)
    chatbot = gr.Chatbot(type="messages", height=400, show_label=False)
    chat_input = gr.MultimodalTextbox(
        interactive=True, file_types=["audio"], placeholder="🎤 Record or type in Latin...",
        show_label=False, sources=["microphone", "upload"]
    )
    with gr.Row():
        audio_output = gr.Audio(label="🔊 Bot Response", autoplay=True, scale=2)
        replay_dropdown = gr.Dropdown(label="🔁 Replay Message", choices=[], scale=1)
    with gr.Row():
        improve_dropdown = gr.Dropdown(label="✨ Select Message to Improve", choices=[], scale=2)
        improve_btn = gr.Button("✨ Improve Grammar", size="sm", variant="secondary", scale=1)
    grammar_explanation = gr.Textbox(label="📝 Grammar Explanation", interactive=False, visible=False)
    with gr.Row():
        translate_dropdown = gr.Dropdown(label="🌍 Select Bot Message to Translate", choices=[], scale=2)
        language_dropdown = gr.Dropdown(
            label="Target Language",
            choices=["English", "Spanish", "French", "German", "Italian", "Portuguese", "Chinese", "Japanese"],
            value="English",
            scale=1
        )
        translate_btn = gr.Button("🌍 Translate", size="sm", variant="secondary", scale=1)
    translation_output = gr.Textbox(label="🌍 Translation", interactive=False, visible=False)
    clear_btn = gr.Button("🗑️ Clear", size="sm")
    # Event handlers
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = chat_msg.then(bot, chatbot, [chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
    replay_dropdown.change(
        lambda msg_id: bot_instance.message_audio.get(msg_id) if msg_id else None,
        inputs=[replay_dropdown], outputs=[audio_output]
    )
    clear_btn.click(clear_all_data, outputs=[chatbot, audio_output, replay_dropdown, improve_dropdown, translate_dropdown])

    def improve_selected_message(history, selected_index):
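        """Run grammar improvement on the selected user message and show the explanation if one was given."""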
        if selected_index is None:
            _, improve_choices, _ = get_dropdown_choices(history)
            return history, gr.Dropdown(choices=improve_choices), gr.Textbox(visible=False)
        improved_history, explanation = improve_message_grammar(history, selected_index)
        _, improve_choices, _ = get_dropdown_choices(improved_history)
        # Hide the box when Gemini returned only the default fallback explanation
        show_explanation = bool(explanation) and explanation != "No explanation provided."
        return improved_history, gr.Dropdown(choices=improve_choices), gr.Textbox(value=explanation if show_explanation else "", visible=show_explanation)

    def translate_selected_message(history, selected_index, target_language):
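        """Translate the selected bot message into the chosen target language."""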
        if selected_index is None or not history or selected_index >= len(history) or history[selected_index]["role"] != "assistant":
            return gr.Textbox(visible=False)
        latin_text = history[selected_index]["content"]
        translation = bot_instance.translate_latin(latin_text, target_language)
        return gr.Textbox(value=f"Original: {latin_text}\n\n{target_language}: {translation}", visible=True)

    improve_btn.click(improve_selected_message, [chatbot, improve_dropdown], [chatbot, improve_dropdown, grammar_explanation])
    translate_btn.click(translate_selected_message, [chatbot, translate_dropdown, language_dropdown], [translation_output])

if __name__ == "__main__":
    # Launch with optimized settings for HF Spaces
    demo.launch(
        server_port=7860,  # Standard HF Spaces port
        share=False,
        show_error=True,
        quiet=False  # Show startup logs
    )