# MobileLLM-Pro / app.py
import os
from threading import Thread
from typing import Dict, List

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Configuration
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_HISTORY_LENGTH = 10
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly, and intelligent assistant. Provide clear, accurate, and thoughtful responses."
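# The MobileLLM-Pro repo exposes "instruct" and "base" subfolders (selected via
# the dropdown in the UI below). If the checkpoint is gated, downloads only work
# after authenticating with HF_TOKEN.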
# Login to Hugging Face (if token is provided)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
try:
login(token=HF_TOKEN)
print("Successfully logged in to Hugging Face")
except Exception as e:
print(f"Warning: Could not login to Hugging Face: {e}")
class MobileLLMChat:
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.model_loaded = False
def load_model(self, version="instruct"):
"""Load the MobileLLM-Pro model and tokenizer"""
try:
print(f"Loading MobileLLM-Pro ({version})...")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
trust_remote_code=True,
subfolder=version
)
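            # trust_remote_code=True lets any custom tokenizer/model code that
            # ships with the repo run locally (assumed to be needed for this
            # checkpoint; harmless if it is not).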
# Load model
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
trust_remote_code=True,
subfolder=version,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
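            # Note: device_map="auto" requires the `accelerate` package; on a
            # CPU-only machine the model is moved explicitly below instead.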
# Set device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
self.model.to(self.device)
self.model.eval()
self.model_loaded = True
print(f"Model loaded successfully on {self.device}")
return True
except Exception as e:
print(f"Error loading model: {e}")
return False
    def format_chat_history(self, history: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
        """Format chat history for the model, keeping only the most recent turns"""
        messages = [{"role": "system", "content": system_prompt}]
        recent = [msg for msg in history if msg["role"] in ("user", "assistant")]
        # Keep only the last MAX_HISTORY_LENGTH turns so the prompt stays bounded
        messages.extend(recent[-MAX_HISTORY_LENGTH:])
        return messages
def generate_response(self, user_input: str, history: List[Dict[str, str]],
system_prompt: str, temperature: float = 0.7,
max_new_tokens: int = MAX_NEW_TOKENS) -> str:
"""Generate a response from the model"""
if not self.model_loaded:
return "Model not loaded. Please try loading the model first."
try:
# Add user message to history
history.append({"role": "user", "content": user_input})
# Format messages
messages = self.format_chat_history(history, system_prompt)
# Apply chat template
inputs = self.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True
).to(self.device)
# Generate response
with torch.no_grad():
outputs = self.model.generate(
inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
            # Decode only the newly generated tokens (everything after the prompt)
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[-1]:],
                skip_special_tokens=True
            ).strip()
# Add assistant response to history
history.append({"role": "assistant", "content": response})
return response
except Exception as e:
return f"Error generating response: {str(e)}"
    def generate_stream(self, user_input: str, history: List[Dict[str, str]],
                        system_prompt: str, temperature: float = 0.7):
        """Generate a streaming response from the model"""
        if not self.model_loaded:
            yield "Model not loaded. Please try loading the model first."
            return
        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})
            # Format messages
            messages = self.format_chat_history(history, system_prompt)
            # Apply chat template
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True
            ).to(self.device)
            # Stream decoded text as it is generated; generation runs in a
            # background thread while the streamer yields text chunks here.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                inputs=inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            # Accumulate the partial response and yield it after each new chunk
            response = ""
            for new_text in streamer:
                response += new_text
                yield response
            thread.join()
            # Add final response to history
            history.append({"role": "assistant", "content": response.strip()})
        except Exception as e:
            yield f"Error generating response: {str(e)}"
# Initialize chat model
chat_model = MobileLLMChat()
def load_model_button(version):
    """Load the model when the button is clicked"""
    success = chat_model.load_model(version)
    if success:
        # Hide the load button and report success in the status box
        return gr.update(visible=False), gr.update(value="Model loaded successfully!")
    return gr.update(visible=True), gr.update(value="Failed to load model. Please check the logs.")
def clear_chat():
    """Clear the chat history and the message box"""
    return [], ""
def chat_fn(message, history, system_prompt, temperature, model_version):
    """Non-streaming chat: returns the updated messages-format history"""
    if not chat_model.model_loaded:
        return history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": "Please load the model first using the button above."},
        ]
    # The chatbot uses the messages format (a list of role/content dicts), so it
    # can be handed to the model wrapper directly; generate_response appends the
    # new user and assistant turns to this list in place.
    working_history = [m for m in history if m.get("role") in ("user", "assistant")]
    chat_model.generate_response(message, working_history, system_prompt, temperature)
    return working_history

def chat_stream_fn(message, history, system_prompt, temperature, model_version):
    """Streaming chat: yields the updated messages-format history as text arrives"""
    if not chat_model.model_loaded:
        yield history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": "Please load the model first using the button above."},
        ]
        return
    working_history = [m for m in history if m.get("role") in ("user", "assistant")]
    display_history = working_history + [{"role": "user", "content": message}]
    # generate_stream appends the user turn to working_history itself and yields
    # the growing assistant response.
    for partial in chat_model.generate_stream(message, working_history, system_prompt, temperature):
        yield display_history + [{"role": "assistant", "content": partial}]
# Create the Gradio interface
with gr.Blocks(
title="MobileLLM-Pro Chat",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
.message {
padding: 12px !important;
border-radius: 8px !important;
margin-bottom: 8px !important;
}
.user-message {
background-color: #e3f2fd !important;
margin-left: 20% !important;
}
.assistant-message {
background-color: #f5f5f5 !important;
margin-right: 20% !important;
}
"""
) as demo:
# Header
gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
        <h1>🤖 MobileLLM-Pro Chat</h1>
<p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
<p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
</div>
""")
# Model loading section
with gr.Row():
with gr.Column(scale=1):
model_version = gr.Dropdown(
choices=["instruct", "base"],
value="instruct",
label="Model Version",
info="Choose between instruct (chat) or base model"
)
            load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
with gr.Column(scale=2):
model_status = gr.Textbox(
label="Model Status",
value="Model not loaded",
interactive=False
)
# Configuration section
    with gr.Accordion("⚙️ Configuration", open=False):
with gr.Row():
system_prompt = gr.Textbox(
value=DEFAULT_SYSTEM_PROMPT,
label="System Prompt",
lines=3,
info="Customize the AI's behavior and personality"
)
with gr.Row():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Controls randomness (higher = more creative)"
)
streaming = gr.Checkbox(
value=True,
label="Enable Streaming",
info="Show responses as they're being generated"
)
# Chat interface
chatbot = gr.Chatbot(
label="Chat History",
height=500,
show_copy_button=True,
bubble_full_width=False,
type="messages"
)
with gr.Row():
msg = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
scale=4,
container=False
)
submit_btn = gr.Button("Send", variant="primary", scale=1)
clear_btn = gr.Button("Clear", scale=0)
# Event handlers
    load_btn.click(
        load_model_button,
        inputs=[model_version],
        outputs=[load_btn, model_status]
    )
    # Handle chat submission: yield the updated history and clear the input box
    def handle_chat(message, history, system_prompt, temperature, model_version, streaming):
        if not message or not message.strip():
            yield history, ""
            return
        if streaming:
            for updated_history in chat_stream_fn(message, history, system_prompt, temperature, model_version):
                yield updated_history, ""
        else:
            yield chat_fn(message, history, system_prompt, temperature, model_version), ""

    msg.submit(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
        outputs=[chatbot, msg]
    )
    submit_btn.click(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
        outputs=[chatbot, msg]
    )
clear_btn.click(
clear_chat,
outputs=[chatbot, msg]
)
# Examples
gr.Examples(
examples=[
["What are the benefits of on-device AI models?"],
["Explain quantum computing in simple terms."],
["Write a short poem about technology."],
["What's the difference between machine learning and deep learning?"],
["How can I improve my productivity?"],
],
inputs=[msg],
label="Example Prompts"
)
# Footer
gr.HTML("""
<div style="text-align: center; margin-top: 20px; color: #666;">
        <p>⚠️ Note: This model requires significant computational resources. Loading may take a few minutes.</p>
<p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
</div>
""")
# Launch the app
if __name__ == "__main__":
demo.launch(
        show_error=True,
        debug=True
)
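# A Space running this app also needs a requirements.txt along these lines
# (an assumed, unpinned list inferred from the imports above; `accelerate`
# backs device_map="auto"):
#   gradio
#   torch
#   transformers
#   accelerate
#   huggingface_hub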