import gradio as gr
import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import os
from typing import List, Dict

# Configuration
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_HISTORY_LENGTH = 10
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly, and intelligent assistant. Provide clear, accurate, and thoughtful responses."

# Login to Hugging Face (if a token is provided)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face")
    except Exception as e:
        print(f"Warning: Could not login to Hugging Face: {e}")


class MobileLLMChat:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, version="instruct"):
        """Load the MobileLLM-Pro model and tokenizer."""
        try:
            print(f"Loading MobileLLM-Pro ({version})...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
            )

            # Load model (fp16 on GPU, fp32 on CPU)
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )

            # Set device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if not torch.cuda.is_available():
                self.model.to(self.device)

            self.model.eval()
            self.model_loaded = True
            print(f"Model loaded successfully on {self.device}")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

    def format_chat_history(self, history: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
        """Format chat history for the model."""
        messages = [{"role": "system", "content": system_prompt}]
        for msg in history:
            if msg["role"] in ["user", "assistant"]:
                messages.append(msg)
        return messages

    def generate_response(self, user_input: str, history: List[Dict[str, str]],
                          system_prompt: str, temperature: float = 0.7,
                          max_new_tokens: int = MAX_NEW_TOKENS) -> str:
        """Generate a response from the model."""
        if not self.model_loaded:
            return "Model not loaded. Please try loading the model first."
        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages and apply the chat template
            messages = self.format_chat_history(history, system_prompt)
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (everything after the prompt)
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            ).strip()

            # Clean up common prefixes the model may emit
            for prefix in ["Assistant:", "assistant:", "Response:", "response:"]:
                if response.lower().startswith(prefix.lower()):
                    response = response[len(prefix):].strip()

            # Add assistant response to history
            history.append({"role": "assistant", "content": response})
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def generate_stream(self, user_input: str, history: List[Dict[str, str]],
                        system_prompt: str, temperature: float = 0.7):
        """Generate a streaming response from the model."""
        if not self.model_loaded:
            yield "Model not loaded. Please try loading the model first."
            return
        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages and apply the chat template
            messages = self.format_chat_history(history, system_prompt)
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Stream tokens with TextIteratorStreamer: generate() runs in a
            # background thread and pushes decoded text into the streamer.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                inputs=inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # Yield the accumulated response as new text arrives
            response = ""
            for new_text in streamer:
                response += new_text
                yield response.strip()
            thread.join()

            # Add final response to history
            history.append({"role": "assistant", "content": response.strip()})
        except Exception as e:
            yield f"Error generating response: {str(e)}"


# Initialize chat model
chat_model = MobileLLMChat()


def load_model_button(version):
    """Load the model when the button is clicked."""
    success = chat_model.load_model(version)
    if success:
        return gr.update(visible=False), gr.update(value="Model loaded successfully!")
    return gr.update(visible=True), gr.update(value="Failed to load model. Please check the logs.")


def clear_chat():
    """Clear the chat history and the message box."""
    return [], ""


def chat_fn(message, history, system_prompt, temperature, model_version):
    """Main (non-streaming) chat function."""
    if not chat_model.model_loaded:
        return "Please load the model first using the button above."
    # The chatbot uses type="messages", so history is already a list of
    # {"role": ..., "content": ...} dicts in the format the model expects.
    return chat_model.generate_response(message, list(history), system_prompt, temperature)


def chat_stream_fn(message, history, system_prompt, temperature, model_version):
    """Streaming chat function."""
    if not chat_model.model_loaded:
        yield "Please load the model first using the button above."
        return
    # History is already in messages format (see chat_fn)
    for chunk in chat_model.generate_stream(message, list(history), system_prompt, temperature):
        yield chunk


# Create the Gradio interface
with gr.Blocks(
    title="MobileLLM-Pro Chat",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .message {
        padding: 12px !important;
        border-radius: 8px !important;
        margin-bottom: 8px !important;
    }
    .user-message {
        background-color: #e3f2fd !important;
        margin-left: 20% !important;
    }
    .assistant-message {
        background-color: #f5f5f5 !important;
        margin-right: 20% !important;
    }
    """
) as demo:
    # Header
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🤖 MobileLLM-Pro Chat</h1>
        <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
        <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
    </div>
    """)

    # Model loading section
    with gr.Row():
        with gr.Column(scale=1):
            model_version = gr.Dropdown(
                choices=["instruct", "base"],
                value="instruct",
                label="Model Version",
                info="Choose between instruct (chat) or base model"
            )
            load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
        with gr.Column(scale=2):
            model_status = gr.Textbox(
                label="Model Status",
                value="Model not loaded",
                interactive=False
            )

    # Configuration section
    with gr.Accordion("⚙️ Configuration", open=False):
        with gr.Row():
            system_prompt = gr.Textbox(
                value=DEFAULT_SYSTEM_PROMPT,
                label="System Prompt",
                lines=3,
                info="Customize the AI's behavior and personality"
            )
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness (higher = more creative)"
            )
            streaming = gr.Checkbox(
                value=True,
                label="Enable Streaming",
                info="Show responses as they're being generated"
            )

    # Chat interface
    chatbot = gr.Chatbot(
        label="Chat History",
        height=500,
        show_copy_button=True,
        bubble_full_width=False,
        type="messages"
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            scale=4,
            container=False
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)
        clear_btn = gr.Button("Clear", scale=0)

    # Event handlers
    load_btn.click(
        load_model_button,
        inputs=[model_version],
        outputs=[load_btn, model_status]
    )

    # Handle chat submission: append the user message, generate a reply
    # (streamed or not), update the chatbot, and clear the input box.
    def handle_chat(message, history, system_prompt, temperature, model_version, streaming):
        message = (message or "").strip()
        if not message:
            yield history, ""
            return
        prior_history = list(history)
        history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": ""},
        ]
        if streaming:
            for chunk in chat_stream_fn(message, prior_history, system_prompt, temperature, model_version):
                history[-1]["content"] = chunk
                yield history, ""
        else:
            history[-1]["content"] = chat_fn(message, prior_history, system_prompt, temperature, model_version)
            yield history, ""

    msg.submit(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
        outputs=[chatbot, msg]
    )
    submit_btn.click(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
        outputs=[chatbot, msg]
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg]
    )

    # Examples
    gr.Examples(
        examples=[
            ["What are the benefits of on-device AI models?"],
            ["Explain quantum computing in simple terms."],
            ["Write a short poem about technology."],
            ["What's the difference between machine learning and deep learning?"],
            ["How can I improve my productivity?"],
        ],
        inputs=[msg],
        label="Example Prompts"
    )

    # Footer
    gr.HTML("""
    <div style="text-align: center; margin-top: 20px; color: #666;">
        <p>⚠️ Note: This model requires significant computational resources. Loading may take a few minutes.</p>
        <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
    </div>
    """)


# Launch the app
if __name__ == "__main__":
    demo.launch(
        show_error=True,
        debug=True
    )