import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import numpy as np
import threading
import logging
import os
import gc
import time
from typing import Generator, List, Dict, Any, Optional
import warnings
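
# Assumed runtime dependencies (an assumption -- no requirements file is shown here):
# gradio, torch, transformers, numpy, and peft, which transformers requires for the
# `load_adapter()` call in ModelManager._load_chatbot_model below.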

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Global variable for generation control
generation_stopped = threading.Event()


# Configuration
class Config:
    """Application configuration"""

    # Model settings
    BASE_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
    ADAPTER_PATH = "echarif/lora_adapter_llama3.2_1B"

    # Generation settings
    MAX_TOKENS_DEFAULT = 512
    MAX_TOKENS_LIMIT = 1024
    TEMPERATURE_DEFAULT = 0.7
    TOP_P_DEFAULT = 0.95

    # System settings
    MAX_HISTORY_LENGTH = 50
    STREAM_TIMEOUT = 30

    # HuggingFace Spaces settings
    HF_SPACE = os.getenv("SPACE_ID") is not None
    # --- MODIFIED: Added HF_TOKEN for gated models ---
    HF_TOKEN = os.getenv("HF_TOKEN")
    DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true"

    # Device and performance
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

    # Error messages
    # --- REMOVED: OFF_TOPIC_MSG as classifier is gone ---
    ERROR_MSG = "Une erreur s'est produite. Veuillez réessayer."
    LOADING_MSG = "Chargement en cours..."


# Configure logging
log_level = logging.WARNING if not Config.DEBUG_MODE else logging.INFO
logging.basicConfig(
    level=log_level,
    format='%(levelname)s - %(message)s' if Config.DEBUG_MODE else '%(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages model loading and inference"""

    def __init__(self):
        self.chat_model: Optional[Any] = None
        self.chat_tokenizer: Optional[Any] = None
        # --- REMOVED: Classifier model attributes ---
        self.models_loaded = False
        self.load_start_time = None

    def load_models(self) -> bool:
        """Load all models with proper error handling"""
        self.load_start_time = time.time()

        # --- MODIFIED: Added check for HF_TOKEN ---
        if Config.HF_SPACE and not Config.HF_TOKEN:
            logger.error("HF_TOKEN secret not found. This is required for gated models on Spaces.")
            return False

        try:
            # Load chatbot model
            if not self._load_chatbot_model():
                return False

            # --- REMOVED: Classifier model loading logic ---

            load_time = time.time() - self.load_start_time
            if Config.DEBUG_MODE:
                logger.info(f"Chatbot model loaded successfully in {load_time:.2f}s")

            self.models_loaded = True
            return True

        except Exception as e:
            logger.error(f"Critical error during model loading: {e}")
            return False

    def _load_chatbot_model(self) -> bool:
        """Load chatbot model with LoRA adapter"""
        try:
            # Load base model
            self.chat_model = AutoModelForCausalLM.from_pretrained(
                Config.BASE_MODEL_ID,
                torch_dtype=Config.TORCH_DTYPE,
                device_map="auto" if Config.DEVICE == "cuda" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                # --- MODIFIED: Added token for authentication ---
                token=Config.HF_TOKEN,
            )

            # Load tokenizer
            self.chat_tokenizer = AutoTokenizer.from_pretrained(
                Config.BASE_MODEL_ID,
                trust_remote_code=True,
                # --- MODIFIED: Added token for authentication ---
                token=Config.HF_TOKEN,
            )

            # Load LoRA adapter
            self.chat_model.load_adapter(Config.ADAPTER_PATH)

            # Configure tokenizer
            if self.chat_tokenizer.pad_token is None:
                self.chat_tokenizer.pad_token = self.chat_tokenizer.eos_token

            # Move to device if needed
            if Config.DEVICE == "cuda" and hasattr(self.chat_model, 'to'):
                self.chat_model = self.chat_model.to(Config.DEVICE)

            return True

        except Exception as e:
            logger.error(f"Failed to load chatbot model: {e}")
            return False

    # --- REMOVED: _load_classifier_model method ---
    # --- REMOVED: classify_input method ---

    def generate_response_stream(
        self,
        message: str,
        max_tokens: int = Config.MAX_TOKENS_DEFAULT,
        temperature: float = Config.TEMPERATURE_DEFAULT,
        top_p: float = Config.TOP_P_DEFAULT
    ) -> Generator[str, None, None]:
        """Generate streaming response - optimized for speed with stop functionality"""
        if not self.models_loaded or not self.chat_model or not self.chat_tokenizer:
            yield Config.ERROR_MSG
            return

        # Reset stop flag
        generation_stopped.clear()

        # Validate parameters
        max_tokens = min(max_tokens, Config.MAX_TOKENS_LIMIT)
        temperature = max(0.1, min(temperature, 1.0))
        top_p = max(0.1, min(top_p, 1.0))

        # --- REMOVED: Classification check ---

        try:
            # Prepare input
            messages = [{"role": "user", "content": message}]
            text = self.chat_tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize
            inputs = self.chat_tokenizer([text], return_tensors='pt').to(Config.DEVICE)

            # Setup streamer
            streamer = TextIteratorStreamer(
                self.chat_tokenizer,
                timeout=10,
                skip_prompt=True,
                skip_special_tokens=True
            )

            # Generation parameters
            generation_kwargs = {
                "input_ids": inputs.input_ids,
                "attention_mask": inputs.attention_mask,  # pass the mask explicitly since pad_token == eos_token
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "streamer": streamer,
                "pad_token_id": self.chat_tokenizer.eos_token_id,
                "eos_token_id": self.chat_tokenizer.eos_token_id,
                "use_cache": True,
            }

            # Start generation thread
            generation_thread = threading.Thread(
                target=self.chat_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.daemon = True
            generation_thread.start()

            # Stream response
            generated_text = ""
            for new_text in streamer:
                if generation_stopped.is_set():
                    break
                if new_text:
                    generated_text += new_text
                    yield generated_text

            generation_thread.join(timeout=0.5)

        except Exception as e:
            if Config.DEBUG_MODE:
                logger.error(f"Generation error: {e}")
            yield Config.ERROR_MSG

        finally:
            if Config.DEVICE == "cuda":
                torch.cuda.empty_cache()

    def cleanup(self):
        """Fast resource cleanup"""
        try:
            if hasattr(self, 'chat_model'):
                del self.chat_model
            if hasattr(self, 'chat_tokenizer'):
                del self.chat_tokenizer
            # --- REMOVED: Classifier cleanup ---
            if Config.DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
        except Exception:
            pass
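

# Illustrative only -- a minimal sketch of driving ModelManager from a Python shell,
# without the Gradio UI. It is not called anywhere in the app, and it assumes the base
# model and LoRA adapter can be downloaded (e.g. with a valid HF_TOKEN for the gated
# Llama 3.2 weights). The prompt is just a placeholder.
def _manual_smoke_test() -> None:
    """Stream one response from the fine-tuned model, then release resources."""
    manager = ModelManager()
    if manager.load_models():
        # generate_response_stream yields the cumulative text generated so far
        for partial in manager.generate_response_stream("Bonjour !", max_tokens=64):
            print(partial)
        manager.cleanup()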


# Initialize model manager
model_manager = ModelManager()


def stop_generation():
    """Stop the current generation"""
    generation_stopped.set()


def chat_interface(
    message: str,
    history: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
    top_p: float
):
    """Main chat interface with streaming and stop functionality"""
    if not message or not message.strip():
        yield "", history, gr.update(value="Envoyer", interactive=True)
        return

    if len(history) > Config.MAX_HISTORY_LENGTH:
        history = history[-Config.MAX_HISTORY_LENGTH:]

    history.append({"role": "user", "content": message.strip()})
    yield "", history, gr.update(value="Stop", interactive=True)

    partial_response = ""
    assistant_message_added = False

    try:
        for response_chunk in model_manager.generate_response_stream(
            message.strip(), max_tokens, temperature, top_p
        ):
            if generation_stopped.is_set():
                break
            partial_response = response_chunk
            if not assistant_message_added:
                current_history = history + [{"role": "assistant", "content": partial_response}]
            else:
                current_history = history[:-1] + [{"role": "assistant", "content": partial_response}]
            yield "", current_history, gr.update(value="Stop", interactive=True)

        if not assistant_message_added:
            history.append({"role": "assistant", "content": partial_response})
            assistant_message_added = True
        else:
            history[-1] = {"role": "assistant", "content": partial_response}

    except Exception as e:
        if Config.DEBUG_MODE:
            logger.error(f"Chat interface error: {e}")
        error_msg = {"role": "assistant", "content": Config.ERROR_MSG}
        if not assistant_message_added:
            history.append(error_msg)
        else:
            history[-1] = error_msg
        yield "", history, gr.update(value="Envoyer", interactive=True)
        return

    yield "", history, gr.update(value="Envoyer", interactive=True)


def create_interface() -> gr.Blocks:
    """Create Claude-like interface with fixed input and collapsible sidebar"""

    # CSS is unchanged
    custom_css = """
    /* Full screen layout */
    .gradio-container {
        max-width: 100% !important;
        width: 100% !important;
        margin: 0 !important;
        padding: 0 !important;
        min-height: 100vh !important;
        display: flex !important;
        flex-direction: column !important;
    }
    .main-content {
        flex: 1 !important;
        display: flex !important;
        height: 100vh !important;
    }
    .chat-area {
        flex: 1 !important;
        display: flex !important;
        flex-direction: column !important;
        height: 100% !important;
    }
    .chat-container {
        flex: 1 !important;
        height: calc(100vh - 200px) !important;
        min-height: 400px !important;
        border: none !important;
    }
    .input-container {
        position: sticky !important;
        bottom: 0 !important;
        background: white !important;
        border-top: 1px solid #e0e0e0 !important;
        padding: 1rem !important;
        z-index: 100 !important;
    }
    .input-row {
        display: flex !important;
        gap: 0.5rem !important;
        align-items: flex-end !important;
    }
    .message-input {
        flex: 1 !important;
        min-height: 24px !important;
        max-height: 120px !important;
        resize: vertical !important;
    }
    .send-button {
        min-width: 80px !important;
        height: 40px !important;
        margin-left: 0.5rem !important;
    }
    .sidebar {
        width: 300px !important;
        min-width: 300px !important;
        border-left: 1px solid #e0e0e0 !important;
        background: #f8f9fa !important;
        transition: margin-right 0.3s ease !important;
        overflow-y: auto !important;
        padding: 1rem !important;
    }
    .sidebar.collapsed {
        margin-right: -300px !important;
    }
    .sidebar-toggle {
        position: fixed !important;
        top: 20px !important;
        right: 20px !important;
        z-index: 200 !important;
        width: 40px !important;
        height: 40px !important;
        border-radius: 50% !important;
        background: #007bff !important;
        color: white !important;
        border: none !important;
        cursor: pointer !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
        font-size: 18px !important;
    }
    .header {
        padding: 1rem !important;
        border-bottom: 1px solid #e0e0e0 !important;
        background: white !important;
    }
    .status-indicator {
        padding: 8px 12px !important;
        border-radius: 4px !important;
        margin: 0.5rem 0 !important;
        font-weight: 500 !important;
        font-size: 0.9rem !important;
    }
    .status-success {
        background-color: #d4edda !important;
        color: #155724 !important;
        border: 1px solid #c3e6cb !important;
    }
    .status-error {
        background-color: #f8d7da !important;
        color: #721c24 !important;
        border: 1px solid #f5c6cb !important;
    }
    .action-buttons {
        display: flex !important;
        gap: 0.5rem !important;
        margin-bottom: 1rem !important;
    }
    @media (max-width: 768px) {
        .sidebar {
            width: 280px !important;
            min-width: 280px !important;
        }
        .sidebar.collapsed {
            margin-right: -280px !important;
        }
        .chat-container {
            height: calc(100vh - 160px) !important;
        }
    }
    """

    with gr.Blocks(
        title="Assistant Éducatif",
        theme=gr.themes.Soft(),
        css=custom_css
    ) as demo:

        sidebar_visible = gr.State(True)

        with gr.Row(elem_classes="header"):
            with gr.Column():
                gr.Markdown("# Assistant Éducatif Intelligent")
                if model_manager.models_loaded:
                    gr.Markdown(
                        "**SYSTÈME OPÉRATIONNEL** - Prêt à répondre",
                        elem_classes="status-indicator status-success"
                    )
                else:
                    gr.Markdown(
                        "**CHARGEMENT EN COURS / ERREUR** - Veuillez vérifier les logs.",
                        elem_classes="status-indicator status-error"
                    )

        with gr.Row(elem_classes="main-content"):
            with gr.Column(elem_classes="chat-area"):
                chatbot = gr.Chatbot(
                    label="",
                    type='messages',
                    elem_classes="chat-container",
                    show_copy_button=True,
                    show_share_button=False,
                    height="100%"
                )

                with gr.Row(elem_classes="action-buttons"):
                    clear_btn = gr.Button("Effacer la conversation", size="sm", variant="secondary")
                    retry_btn = gr.Button("Réessayer", size="sm", variant="secondary")

                with gr.Row(elem_classes="input-container"):
                    with gr.Column(elem_classes="input-row"):
                        user_input = gr.Textbox(
                            label="",
                            placeholder="Posez votre question...",
                            lines=1,
                            max_lines=4,
                            elem_classes="message-input",
                            show_label=False
                        )
                        with gr.Row():
                            submit_btn = gr.Button(
                                "Envoyer",
                                variant="primary",
                                elem_classes="send-button"
                            )
                            stop_btn = gr.Button(
                                "Stop",
                                variant="stop",
                                visible=False,
                                elem_classes="send-button"
                            )

            with gr.Column(elem_classes="sidebar", visible=True) as sidebar:
                gr.Markdown("### Paramètres de génération")
                max_tokens = gr.Slider(minimum=50, maximum=Config.MAX_TOKENS_LIMIT, value=Config.MAX_TOKENS_DEFAULT, step=10, label="Longueur maximale")
                temperature = gr.Slider(minimum=0.1, maximum=1.0, value=Config.TEMPERATURE_DEFAULT, step=0.05, label="Créativité")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=Config.TOP_P_DEFAULT, step=0.05, label="Diversité")

                gr.Markdown("---")
                gr.Markdown("### Informations système")
                system_info = f"""
                **Device:** {Config.DEVICE.upper()}
                **Précision:** {str(Config.TORCH_DTYPE).split('.')[-1]}
                **Status:** {'✓ Prêt' if model_manager.models_loaded else '⏳ Chargement / Erreur'}
                """
                if model_manager.load_start_time and model_manager.models_loaded:
                    load_time = time.time() - model_manager.load_start_time
                    system_info += f" \n**Temps:** {load_time:.1f}s"
                gr.Markdown(system_info)

                with gr.Accordion("Conseils d'utilisation", open=False):
                    gr.Markdown("""
                    - Questions claires et précises
                    - Spécifiez le niveau si pertinent
                    - Utilisez un langage approprié
                    """)

        toggle_btn = gr.Button("⚙", elem_classes="sidebar-toggle")

        # Event handlers
        def clear_conversation():
            return [], ""

        def retry_last_message(history):
            if history and len(history) >= 2 and history[-2]["role"] == "user":
                last_user_msg = history[-2]["content"]
                new_history = history[:-2]
                return last_user_msg, new_history
            elif history and history[-1]["role"] == "user":
                last_user_msg = history[-1]["content"]
                new_history = history[:-1]
                return last_user_msg, new_history
            return "", history

        def toggle_sidebar(visible):
            return not visible, gr.update(visible=not visible)

        def handle_stop():
            stop_generation()
            return gr.update(value="Envoyer", interactive=True)

        # Wire up events
        submit_action = user_input.submit(
            chat_interface,
            inputs=[user_input, chatbot, max_tokens, temperature, top_p],
            outputs=[user_input, chatbot, submit_btn],
            show_progress=False
        )

        click_action = submit_btn.click(
            chat_interface,
            inputs=[user_input, chatbot, max_tokens, temperature, top_p],
            outputs=[user_input, chatbot, submit_btn],
            show_progress=False
        )

        stop_btn.click(handle_stop, outputs=[submit_btn])
        clear_btn.click(clear_conversation, outputs=[chatbot, user_input])
        retry_btn.click(retry_last_message, inputs=[chatbot], outputs=[user_input, chatbot]).then(
            chat_interface,
            inputs=[user_input, chatbot, max_tokens, temperature, top_p],
            outputs=[user_input, chatbot, submit_btn],
            show_progress=False
        )
        toggle_btn.click(toggle_sidebar, inputs=[sidebar_visible], outputs=[sidebar_visible, sidebar])

    return demo


# Load models
logger.info("Initializing Educational Assistant...")
models_loaded = model_manager.load_models()
if not models_loaded:
    logger.error("Failed to load models. Application may not function properly.")

# Create interface
demo = create_interface()

# Launch configuration
if __name__ == "__main__":
    try:
        demo.queue(max_size=20)

        launch_kwargs = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "show_error": True,
            "show_api": False,
            "share": False,
        }

        if os.getenv("GRADIO_AUTH"):
            auth_pairs = [tuple(auth.split(':')) for auth in os.getenv("GRADIO_AUTH").split(',')]
            launch_kwargs["auth"] = auth_pairs
            logger.info("Authentication enabled")

        logger.info("Launching application...")
        demo.launch(**launch_kwargs)

    except KeyboardInterrupt:
        logger.info("Application stopped by user")
    except Exception as e:
        logger.error(f"Launch error: {e}")
    finally:
        model_manager.cleanup()
        logger.info("Application shutdown complete")
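
# Example local launch (a sketch, not part of the app logic). It assumes this file is
# saved as app.py -- the usual Spaces entrypoint -- and that the placeholder values
# below are replaced with real ones:
#
#   export HF_TOKEN=hf_xxx            # required to download the gated Llama 3.2 base model
#   export DEBUG=true                 # optional: verbose logging
#   export GRADIO_AUTH=user:password  # optional: comma-separated user:pass pairs
#   python app.py                     # serves the UI on 0.0.0.0:7860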