# QwQ-Edge / app.py
import os
from collections.abc import Iterator
from threading import Thread
from typing import Dict, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """
# QwQ Distill
"""
css = '''
h1 {
    text-align: center;
    display: block;
}
#duplicate-button {
    margin: auto;
    color: #fff;
    background: #1565c0;
    border-radius: 100vh;
}
'''
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
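# Note: despite the "QwQ Distill" title above, the checkpoint loaded here is
# DeepSeek's R1 distill of Qwen 1.5B.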
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
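# Cap the attention sliding window (a Qwen2-style config field). In recent
# transformers releases this may only take effect if sliding-window attention
# is actually enabled for the architecture.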
model.config.sliding_window = 4096
model.eval()
# Set the pad token ID if it's not already set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Define roles for the chat
class Role:
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
# Default system message
default_system = "You are a helpful assistant."
def clear_session() -> Tuple[str, List]:
    return "", []
def modify_system_session(system: str) -> Tuple[str, str, List]:
    # Fall back to the default prompt when the system box is empty.
    if not system:
        system = default_system
    return system, system, []
def history_to_messages(history: List, system: str) -> List[Dict]:
    messages = [{'role': Role.SYSTEM, 'content': system}]
    for h in history:
        messages.append({'role': Role.USER, 'content': h[0]})
        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
    return messages
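# For reference, history_to_messages([("Hi", "Hello!")], default_system) yields:
#   [{'role': 'system', 'content': default_system},
#    {'role': 'user', 'content': 'Hi'},
#    {'role': 'assistant', 'content': 'Hello!'}]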
@spaces.GPU(duration=120)
def generate(
    query: Optional[str],
    history: Optional[List],
    system: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    if query is None:
        query = ''
    if history is None:
        history = []
    # Convert history to messages
    messages = history_to_messages(history, system)
    messages.append({'role': Role.USER, 'content': query})
    # Apply chat template and tokenize
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
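    # Assumed safeguard: MAX_INPUT_TOKEN_LENGTH is defined above but never
    # enforced in the original; trim to the most recent tokens on overflow.
    if model_inputs["input_ids"].shape[1] > MAX_INPUT_TOKEN_LENGTH:
        model_inputs["input_ids"] = model_inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")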
    # Set up the streamer for real-time text generation
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Run generation on a background thread so tokens can be yielded as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Stream the output tokens
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System Message", value=default_system, lines=2),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Write a Python function that reverses a string if its length is a multiple of 4."],
        ["What is the volume of a pyramid with a rectangular base?"],
        ["Explain the difference between list comprehensions and lambda functions in Python."],
        ["What happens when the sun goes down?"],
    ],
    cache_examples=False,
    description=DESCRIPTION,
    css=css,
    fill_height=True,
)
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)