import re
import threading
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
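
# CSS: make the reasoning panel scrollable and size it to the viewport height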
CSS = """
.m3d-auto-scroll > * {
overflow: auto;
}
#reasoning {
overflow: auto;
height: calc(100vh - 128px);
scroll-behavior: smooth;
}
"""
JS = """
() => {
    // keep the #reasoning block scrolled to the bottom whenever its text changes
const block = document.querySelector('#reasoning');
const observer = new MutationObserver((mutations) => {
block.scrollTop = block.scrollHeight;
})
observer.observe(block, {
childList: true,
characterData: true,
subtree: true,
});
}
"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def reformat_math(text):
"""Fix MathJax delimiters to use the Gradio syntax.
    This is a workaround to display math formulas in Gradio. For now, I haven't
    found a way to make it work as expected using other latex_delimiters...
"""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
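
# spaces.GPU allocates a GPU for the duration of the call on Hugging Face ZeroGPU Spaces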
@spaces.GPU
def chat(prompt, history):
"""Respond to a chat prompt."""
message = {
"role": "user",
"content": prompt,
}
# build the messages list
history = [] if history is None else history
message_list = history + [message]
text = tokenizer.apply_chat_template(
message_list,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the prompt (chat template) out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
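    # run generation in a background thread so we can consume the streamer as tokens arrive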
threading.Thread(
target=model.generate,
kwargs=dict(
max_new_tokens=1024 * 128,
streamer=streamer,
**model_inputs,
),
).start()
buffer = ""
reasoning = ""
thinking = False
reasoning_heading = "# Reasoning\n\n"
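    # DeepSeek-R1 wraps its chain of thought in <think> ... </think> tags:
    # route that text to the reasoning panel and stream everything after to the chat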
for new_text in streamer:
if not thinking and "<think>" in new_text:
thinking = True
continue
if thinking and "</think>" in new_text:
thinking = False
continue
if thinking:
reasoning += new_text
yield (
"I'm thinking, please wait a moment...",
reasoning_heading + reasoning,
)
continue
buffer += new_text
yield reformat_math(buffer), reasoning_heading + reasoning
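
# chatbot component configured to render the LaTeX delimiters produced by reformat_math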
chat_bot = gr.Chatbot(
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
],
scale=1,
type="messages",
)
with gr.Blocks(
theme="davehornik/Tealy",
js=JS,
css=CSS,
fill_height=True,
title="Reasoning model example",
) as demo:
reasoning = gr.Markdown(
"# Reasoning\n\nWhen the model will reasoning, its thoughts will be displayed here.",
label="Reasoning",
show_label=True,
container=True,
elem_classes="m3d-auto-scroll",
render=False,
)
with gr.Row(equal_height=True):
with gr.Column(scale=3, variant="compact"):
gr.ChatInterface(
chat,
type="messages",
chatbot=chat_bot,
title="Simple conversational AI with reasoning",
                description=(
                    f"We're using the **{model_name}** model. It is a large language "
                    "model trained on a mixture of instruction and conversational "
                    "data, and it is able to reason about the prompt (the user "
                    "question). While it thinks, you can follow its thoughts in "
                    "the panel on the right."
                ),
additional_outputs=[reasoning],
)
with gr.Column(elem_id="reasoning"):
reasoning.render()
if __name__ == "__main__":
demo.queue().launch()