# Hugging Face Space: "Running on Zero" (ZeroGPU) demo.
import re
import threading

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Page styling: the #reasoning column gets a fixed viewport-relative height and
# smooth scrolling so the live reasoning stream stays readable while it grows.
CSS = """
.m3d-auto-scroll > * {
overflow: auto;
}
#reasoning {
overflow: auto;
height: calc(100vh - 128px);
scroll-behavior: smooth;
}
"""
# Client-side hook loaded by gr.Blocks(js=...): a MutationObserver pins the
# #reasoning panel's scroll position to the bottom whenever its text changes,
# so the newest reasoning tokens are always visible.
JS = """
() => {
// auto scroll .auto-scroll elements when text has changed
const block = document.querySelector('#reasoning');
const observer = new MutationObserver((mutations) => {
block.scrollTop = block.scrollHeight;
})
observer.observe(block, {
childList: true,
characterData: true,
subtree: true,
});
}
"""
# Checkpoint: a distilled DeepSeek-R1 reasoning model (1.5B parameters) that
# emits its chain of thought between <think>...</think> markers.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load the weights once at startup; dtype and device placement are resolved
# automatically by transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
# Log the resolved model configuration at startup for debugging deployments.
# (A stray `print(dir(model))` debug dump was removed here.)
print(model.config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
def reformat_math(text):
    """Convert MathJax-style delimiters to the Gradio LaTeX syntax.

    Workaround for math rendering in Gradio: rewrite display math
    ``\\[ ... \\]`` as ``$$ ... $$`` and inline math ``\\( ... \\)`` as
    ``$ ... $``, since configuring other latex_delimiters has not worked
    as expected so far.
    """
    rewrites = (
        (r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$"),
        (r"\\\(\s*(.*?)\s*\\\)", r"$\1$"),
    )
    for pattern, template in rewrites:
        text = re.sub(pattern, template, text, flags=re.DOTALL)
    return text
def chat(prompt, history):
    """Stream a reply to *prompt*, separating the model's hidden reasoning.

    Yields ``(answer, reasoning_markdown)`` pairs so the ChatInterface can
    update the chat bubble and the side reasoning panel together. Text
    between ``<think>`` and ``</think>`` goes to the reasoning panel; the
    rest becomes the visible answer.
    """
    conversation = (history or []) + [{"role": "user", "content": prompt}]

    chat_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)

    # Generation runs on a worker thread; the streamer hands tokens back to
    # this generator as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    worker = threading.Thread(
        target=model.generate,
        kwargs={"max_new_tokens": 1024 * 128, "streamer": streamer, **inputs},
    )
    worker.start()

    heading = "# Reasoning\n\n"
    answer = ""
    thoughts = ""
    in_think = False
    for chunk in streamer:
        # The chunks carrying the <think>/</think> markers only toggle the
        # mode and are themselves dropped from both output streams.
        if not in_think and "<think>" in chunk:
            in_think = True
        elif in_think and "</think>" in chunk:
            in_think = False
        elif in_think:
            thoughts += chunk
            yield (
                "I'm thinking, please wait a moment...",
                heading + thoughts,
            )
        else:
            answer += chunk
            yield reformat_math(answer), heading + thoughts
# Chat widget shared with the ChatInterface below. The LaTeX delimiters match
# what reformat_math() produces: $$...$$ for display math, $...$ for inline.
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",  # history as openai-style {"role", "content"} dicts
)
# Page layout: a ChatInterface on the left (3/4 width) and a live-scrolling
# reasoning panel on the right, fed via additional_outputs from chat().
with gr.Blocks(
    theme="davehornik/Tealy",
    js=JS,  # auto-scrolls the #reasoning panel as text streams in
    css=CSS,
    fill_height=True,
    title="Reasoning model example",
) as demo:
    # Created with render=False so it can be declared before the layout and
    # wired as an additional output of the chat function, then placed below.
    reasoning = gr.Markdown(
        # Fixed grammar of the user-facing placeholder text
        # ("When the model will reasoning" -> "When the model is reasoning").
        "# Reasoning\n\nWhen the model is reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="m3d-auto-scroll",
        render=False,
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=3, variant="compact"):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title="Simple conversational AI with reasoning",
                description=(
                    f"We're using the **{model_name}**. It is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data. It has the capability to reason about the "
                    "prompt (the user question). "
                    "When you ask a question, you can see its thoughts "
                    "on the left block."
                ),
                # Second element of each yielded tuple from chat() lands here.
                additional_outputs=[reasoning],
            )
        with gr.Column(elem_id="reasoning"):
            reasoning.render()
if __name__ == "__main__":
    # Queuing is required so the chat() generator streams incrementally.
    app = demo.queue()
    app.launch()