"""
Gradio chat app for facebook/MobileLLM-Pro

- Uses the model's chat template when the "instruct" subfolder is selected
- Streams tokens to the Gradio UI
- Minimal controls: max_new_tokens, temperature, top_p
- Optional HF_TOKEN login via env var or textbox

To run locally:

    pip install -U gradio transformers accelerate sentencepiece huggingface_hub
    HF_TOKEN=xxxx python app.py

On Hugging Face Spaces:

- Remove the explicit login() call or set HF_TOKEN as a Space secret
"""

import os
import threading
from typing import Any, Dict, List, Tuple

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
MODEL_ID = "facebook/MobileLLM-Pro"
DEFAULT_VERSION = "instruct"  # "base" | "instruct"
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.95
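# A requirements.txt is not part of this listing; a minimal sketch, assuming it
# simply mirrors the pip install command in the docstring above plus the ZeroGPU
# helper package:
#
#   gradio
#   transformers
#   accelerate
#   sentencepiece
#   huggingface_hub
#   torch
#   spaces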
# ---- Optional: login to Hugging Face if a token is provided ----
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("[INFO] Logged in to Hugging Face Hub.")
    except Exception as e:
        print(f"[WARN] Could not login to Hugging Face: {e}")
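# This env-var login runs once at import time; the HF_TOKEN textbox and
# "Login to HF" button in the UI below offer a second chance to log in at runtime.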
def load_model(version: str = DEFAULT_VERSION):
    """Load tokenizer + model for the selected subfolder (base/instruct)."""
    print(f"[INFO] Loading {MODEL_ID}:{version} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, subfolder=version
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        subfolder=version,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Ensure a pad token is set to avoid warnings during generation
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    print("[INFO] Model loaded.")
    return tokenizer, model
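# A minimal sketch (commented out) of exercising load_model outside the UI;
# the prompt string below is illustrative only:
#
#   tok, mdl = load_model("instruct")
#   ids = tok("Hello, who are you?", return_tensors="pt")["input_ids"].to(
#       next(mdl.parameters()).device
#   )
#   out = mdl.generate(ids, max_new_tokens=32)
#   print(tok.decode(out[0], skip_special_tokens=True))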
def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Map Gradio history [(user, assistant), ...] to chat-template messages."""
    messages: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages
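# For example, [("Hi", "Hello!")] becomes
# [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
# the message format expected by tokenizer.apply_chat_template below.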
@spaces.GPU  # on a ZeroGPU Space, request a GPU for the duration of this call
def generate_stream(
    message: str,
    history: List[Tuple[str, str]],
    version: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    use_chat_template: bool,
    state: Dict[str, Any],
):
    """Streaming text generator for the chat UI (gr.ChatInterface-style signature).

    Args map to the UI controls; `state` caches the tokenizer/model between calls.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")

    # (Re)load the model if the version changed or nothing is loaded yet
    if (
        tokenizer is None
        or model is None
        or state.get("version") != version
    ):
        tokenizer, model = load_model(version)
        state["tokenizer"], state["model"], state["version"] = tokenizer, model, version

    device = next(model.parameters()).device

    if use_chat_template and version == "instruct":
        messages = _history_to_messages(history) + [
            {"role": "user", "content": message}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(device)
        input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
    else:
        input_ids = tokenizer(
            message,
            return_tensors="pt",
            add_special_tokens=True,
        )["input_ids"].to(device)
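    # Either way, input_ids is now a 2-D LongTensor on the model's device: the
    # instruct path went through the chat template, the base path encoded the raw prompt.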
    # skip_prompt=True keeps the echoed input prompt out of the streamed output
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    do_sample = temperature > 0.0
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Only pass sampling parameters when sampling; greedy decoding ignores them
    # and transformers warns if they are set alongside do_sample=False.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
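    # model.generate blocks until generation finishes, so it runs in a worker
    # thread while this generator drains the streamer and yields growing partial text.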
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    output_text = ""
    for new_text in streamer:
        output_text += new_text
        yield output_text
    thread.join()
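# A minimal sketch (commented out) of driving generate_stream outside Gradio;
# the prompt and settings below are illustrative only:
#
#   _state = {"tokenizer": None, "model": None, "version": None}
#   for partial in generate_stream(
#       "Hello!", [], "instruct", 64, 0.7, 0.95, True, _state
#   ):
#       print(partial)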
with gr.Blocks(title="MobileLLM-Pro Chat") as demo:
    gr.Markdown("""
# facebook/MobileLLM-Pro — Chat Demo
- **Version**: choose `instruct` to enable the model's chat template.
- **Streaming** is enabled. Use the controls in the right panel.
""")
    gr.Markdown(
        "<div style='text-align: center;'>Built with <a href='https://huggingface.co/spaces/akhaliq/anycoder'>anycoder</a></div>",
        elem_id="anycoder_attribution",
    )

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=420, label="MobileLLM-Pro")
            msg = gr.Textbox(placeholder="Ask me anything…", scale=1)
            submit = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")
        with gr.Column(scale=2):
            version = gr.Dropdown(["base", "instruct"], value=DEFAULT_VERSION, label="Subfolder (version)")
            use_chat_template = gr.Checkbox(value=True, label="Use chat template (instruct only)")
            max_new = gr.Slider(32, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
            temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
            hf_token_box = gr.Textbox(value=os.getenv("HF_TOKEN", ""), label="HF_TOKEN (optional)", type="password")
    state = gr.State({"tokenizer": None, "model": None, "version": None})
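    # gr.State is per-session, so each browser session keeps its own cached
    # tokenizer/model reference and only reloads when the selected version changes.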
    def _maybe_login(token: str):
        token = (token or "").strip()
        if not token:
            return "(No token provided; skipping login)"
        try:
            login(token=token)
            return "Logged in to Hugging Face Hub."
        except Exception as e:
            return f"Login failed: {e}"

    login_btn = gr.Button("Login to HF (optional)")
    login_status = gr.Markdown()
    login_btn.click(_maybe_login, inputs=[hf_token_box], outputs=[login_status])
    def user_submit(user_message, chat_history):
        # Append the user's message immediately so the streamed reply shows inline
        return "", chat_history + [(user_message, None)]

    def bot_respond(chat_history, version, max_new, temperature, top_p, use_chat_template, state):
        # The last history entry is (user, None)
        user_message = chat_history[-1][0] if chat_history else ""
        partials = generate_stream(
            user_message,
            chat_history[:-1],
            version,
            int(max_new),
            float(temperature),
            float(top_p),
            bool(use_chat_template),
            state,
        )
        # Stream tokens into the last assistant message slot
        for chunk in partials:
            chat_history[-1] = (chat_history[-1][0], chunk)
            yield chat_history
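    # user_submit and bot_respond are chained with .then(): the first call clears
    # the textbox and appends the user turn, the second is a generator, so Gradio
    # re-renders the Chatbot on every yielded partial response.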
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
        [chatbot],
    )
    submit.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
        [chatbot],
    )
    def clear_chat():
        return []

    clear_btn.click(clear_chat, outputs=[chatbot])
if __name__ == "__main__":
    # For Spaces, Gradio will call `demo.launch()` automatically; locally we launch here.
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))