Spaces: Running on Zero
from __future__ import annotations

import os
from typing import List, Tuple, Dict, Any

import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ----------------------
# Config
# ----------------------
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")

DEFAULT_SYSTEM_PROMPT = (
    "You are a user who wants to implement a special type of sequence. "
    "The sequence sums up the two previous numbers in the sequence and adds 1 to the result. "
    "The first two numbers in the sequence are 1 and 1."
)
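# For reference, the sequence described by the default prompt runs 1, 1, 3, 5, 9, 15, 25, ...
# (each term is the sum of the two previous terms, plus 1).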
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(model_id: str = MODEL_ID):
    """Load tokenizer and model, with a reasonable dtype and device fallback."""
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
    )

    # Special tokens for stopping / filtering
    end_token = "<|eot_id|>"
    end_conv_token = "<|endconversation|>"
    end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
    end_conv_token_ids = tokenizer.encode(end_conv_token, add_special_tokens=False)

    # Some models may not include these tokens — handle gracefully
    eos_token_id = end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
    bad_words_ids = (
        [[tid] for tid in end_conv_token_ids] if len(end_conv_token_ids) > 0 else None
    )
    return tokenizer, model, eos_token_id, bad_words_ids


tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS = load_model()
model = model.to(device)
model.eval()
# ----------------------
# Generation helper
# ----------------------
def build_messages(system_prompt: str, history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Transform Gradio history [(user, assistant), ...] into chat template messages."""
    messages: List[Dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
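# Illustrative example of the mapping above (not executed): with a non-empty system
# prompt and history [("Hi! What do you need?", "Let's get started.")], build_messages
# returns:
#   [{"role": "system", "content": <system prompt>},
#    {"role": "user", "content": "Hi! What do you need?"},
#    {"role": "assistant", "content": "Let's get started."}]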
@spaces.GPU  # on a ZeroGPU Space, request a GPU for the duration of each call
def generate_reply(
    messages: List[Dict[str, str]],
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.9,
) -> str:
    """Run a single generate() step and return the model's text reply."""
    # Prepare input ids using the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            eos_token_id=EOS_TOKEN_ID,
            pad_token_id=tokenizer.eos_token_id,
            bad_words_ids=BAD_WORDS_IDS,
        )
    # Slice off the prompt tokens to get only the new text
    generated = outputs[0][inputs.shape[1]:]
    text = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return text
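# Illustrative smoke test (uses only the definitions above; not wired into the UI):
#   test_messages = build_messages(DEFAULT_SYSTEM_PROMPT, [("Hi! What can I help you with?", "")])
#   print(generate_reply(test_messages, max_new_tokens=64))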
# ----------------------
# Gradio UI callbacks
# ----------------------
def respond(user_message: str, chat_history: List[Tuple[str, str]], system_prompt: str,
            max_new_tokens: int, temperature: float, top_p: float):
    # Build messages including prior turns
    messages = build_messages(system_prompt, chat_history + [(user_message, "")])
    try:
        reply = generate_reply(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        reply = f"(Generation error: {e})"
    chat_history = chat_history + [(user_message, reply)]
    return chat_history, chat_history


def clear_state():
    return [], DEFAULT_SYSTEM_PROMPT
# ----------------------
# Build the Gradio App
# ----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧪 Transformers × Gradio: Multi‑turn Chat Demo
Model: **{model}** on **{device}**

Change the system prompt, then chat. Sliders control sampling.
""".format(model=MODEL_ID, device=device))

    with gr.Row():
        system_box = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            placeholder="Enter a system instruction to steer the assistant",
        )

    chatbot = gr.Chatbot(height=420, label="Chat")

    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            placeholder="Type a message and press Enter",
        )

    with gr.Accordion("Generation Settings", open=False):
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    state = gr.State([])  # chat history state: List[Tuple[user, assistant]]

    def _submit(user_text, history, system_prompt, mnt, temp, tp):
        if not user_text or not user_text.strip():
            return gr.update(), history
        new_history, visible = respond(user_text.strip(), history, system_prompt, mnt, temp, tp)
        return "", visible

    submit_btn.click(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )
    msg.submit(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )

    # Keep state in sync with the visible Chatbot
    def _sync_state(chat):
        return chat

    chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])

    def _clear():
        history, sys = clear_state()
        return history, sys, history, ""

    clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])


if __name__ == "__main__":
    demo.queue().launch()  # enable queuing for concurrency