# UserLM-8B / app.py
from __future__ import annotations
import os
from typing import List, Tuple, Dict
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ----------------------
# Config
# ----------------------
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")
DEFAULT_SYSTEM_PROMPT = (
"You are a user who wants to implement a special type of sequence. "
"The sequence sums up the two previous numbers in the sequence and adds 1 to the result. "
"The first two numbers in the sequence are 1 and 1."
)
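# Note: per the microsoft/UserLM-8b model card, the model is trained to simulate the
# *user* side of a conversation, so the system prompt describes the goal the
# simulated user is pursuing rather than an assistant persona.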
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(model_id: str = MODEL_ID):
"""Load tokenizer and model, with a reasonable dtype and device fallback."""
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=dtype,
)
# Special tokens for stopping / filtering
end_token = "<|eot_id|>"
end_conv_token = "<|endconversation|>"
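    # <|eot_id|> ends a single generated turn and is used as the stop token below;
    # <|endconversation|> signals that the simulated user wants to end the whole
    # conversation, so it is suppressed via bad_words_ids to keep the chat going.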
end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
end_conv_token_ids = tokenizer.encode(end_conv_token, add_special_tokens=False)
# Some models may not include these tokens — handle gracefully
eos_token_id = end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
bad_words_ids = (
[[tid] for tid in end_conv_token_ids] if len(end_conv_token_ids) > 0 else None
)
return tokenizer, model, eos_token_id, bad_words_ids
tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS = load_model()
model = model.to(device)
model.eval()
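# Inference only: eval() disables dropout, and generation below runs under torch.no_grad().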
# ----------------------
# Generation helper
# ----------------------
def build_messages(system_prompt: str, history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
"""Transform Gradio history [(user, assistant), ...] into chat template messages."""
messages: List[Dict[str, str]] = []
if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt.strip()})
for user_msg, assistant_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
return messages
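# Illustrative result for one prior exchange plus a new user turn:
#   [{"role": "system", "content": "..."},
#    {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."},
#    {"role": "user", "content": "<latest message>"}]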
@spaces.GPU
def generate_reply(
messages: List[Dict[str, str]],
max_new_tokens: int = 256,
temperature: float = 0.8,
top_p: float = 0.9,
) -> str:
"""Run a single generate() step and return the model's text reply."""
# Prepare input ids using the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,  # also returns the attention mask
    ).to(device)
with torch.no_grad():
        outputs = model.generate(
            **inputs,  # input_ids and attention_mask
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            eos_token_id=EOS_TOKEN_ID,
            pad_token_id=tokenizer.eos_token_id,
            bad_words_ids=BAD_WORDS_IDS,
        )
    # Slice off the prompt tokens to get only the new text
    generated = outputs[0][inputs["input_ids"].shape[1]:]
text = tokenizer.decode(generated, skip_special_tokens=True).strip()
return text
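# Standalone usage sketch (hypothetical inputs, outside the Gradio callbacks):
#   generate_reply(
#       [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
#        {"role": "user", "content": "Hi! What would you like to build today?"}],
#       max_new_tokens=128,
#   )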
# ----------------------
# Gradio UI callbacks
# ----------------------
def respond(user_message: str, chat_history: List[Tuple[str, str]], system_prompt: str,
max_new_tokens: int, temperature: float, top_p: float):
# Build messages including prior turns
messages = build_messages(system_prompt, chat_history + [(user_message, "")])
try:
reply = generate_reply(
messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
except Exception as e:
reply = f"(Generation error: {e})"
chat_history = chat_history + [(user_message, reply)]
return chat_history, chat_history
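# respond() returns the updated history twice (one copy for the State, one for the
# visible Chatbot); the _submit wrapper below forwards only the visible copy and
# relies on chatbot.change to keep the State in sync.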
def clear_state():
return [], DEFAULT_SYSTEM_PROMPT
# ----------------------
# Build the Gradio App
# ----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🧪 Transformers × Gradio: Multi‑turn Chat Demo
Model: **{model}** on **{device}**
Change the system prompt, then chat. Sliders control sampling.
""".format(model=MODEL_ID, device=device))
with gr.Row():
system_box = gr.Textbox(
label="System Prompt",
value=DEFAULT_SYSTEM_PROMPT,
lines=3,
placeholder="Enter a system instruction to steer the assistant",
)
chatbot = gr.Chatbot(height=420, label="Chat")
with gr.Row():
msg = gr.Textbox(
label="Your message",
placeholder="Type a message and press Enter",
)
with gr.Accordion("Generation Settings", open=False):
max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
with gr.Row():
submit_btn = gr.Button("Send", variant="primary")
clear_btn = gr.Button("Clear")
state = gr.State([]) # chat history state: List[Tuple[user, assistant]]
def _submit(user_text, history, system_prompt, mnt, temp, tp):
if not user_text or not user_text.strip():
return gr.update(), history
new_history, visible = respond(user_text.strip(), history, system_prompt, mnt, temp, tp)
return "", visible
submit_btn.click(
fn=_submit,
inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
outputs=[msg, chatbot],
)
msg.submit(
fn=_submit,
inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
outputs=[msg, chatbot],
)
# Keep state in sync with the visible Chatbot
def _sync_state(chat):
return chat
chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])
def _clear():
history, sys = clear_state()
return history, sys, history, ""
clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])
if __name__ == "__main__":
demo.queue().launch() # enable queuing for concurrency