import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
import gc
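
# Free Python-level garbage and the CUDA allocator cache between generations.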
def flush():
    gc.collect()
    torch.cuda.empty_cache()

torch.set_float32_matmul_precision("high")
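
# Optional token for gated or private repos; None falls back to anonymous access.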
HF_TOKEN = os.getenv("HF_TOKEN", None)

#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
DESCRIPTION = f'''
<div>
<h1 style="text-align: center;">{REPO_ID}</h1>
</div>
'''

PLACEHOLDER = f"""
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """ | |
h1 { | |
text-align: center; | |
display: block; | |
} | |
#duplicate-button { | |
margin: auto; | |
color: white; | |
background: #1565c0; | |
border-radius: 100vh; | |
} | |
""" | |
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, token=HF_TOKEN)
if torch.cuda.is_available():
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, device_map="auto", quantization_config=nf4_config, token=HF_TOKEN)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32, token=HF_TOKEN)

# The TextIteratorStreamer is created per request inside chat_stream so that
# repeated or concurrent calls do not share streamer state.
@spaces.GPU  # needed on ZeroGPU Spaces; a no-op elsewhere
def chat_stream(message: str,
                history: list[dict],
                temperature: float,
                max_new_tokens: int,
                top_p: float,
                top_k: int,
                repetition_penalty: float,
                sys_prompt: str,
                progress=gr.Progress(track_tqdm=True)
                ):
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})
        # Render the history plus the new turn with the model's chat template.
        input_tensors = tokenizer.apply_chat_template(
            [{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x],
            add_generation_prompt=True, return_dict=True, add_special_tokens=False,
            return_tensors="pt").to(model.device)
        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]
        # Per-request streamer; the timeout is raised from 10s so a slow first
        # token from the 4-bit 32B model does not raise queue.Empty mid-stream.
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False  # greedy decoding at T == 0
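        # Run generate() on a worker thread; the streamer yields text chunks as
        # they are produced so the UI can update incrementally.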
        response.append({"role": "assistant", "content": ""})
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        for text in streamer:
            response[-1]["content"] += text
            yield response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
    finally:
        flush()
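
# Non-streaming variant (the UI below wires up chat_stream instead); returns
# the whole completion at once.
@spaces.GPU  # needed on ZeroGPU Spaces; a no-op elsewhere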
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})
        input_tensors = tokenizer.apply_chat_template(
            [{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x],
            add_generation_prompt=True, return_dict=True, add_special_tokens=False,
            return_tensors="pt").to(model.device)
        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]
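        # Same sampling settings as the streaming path, minus the streamer.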
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False  # greedy decoding at T == 0
        response.append({"role": "assistant", "content": ""})
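        # Generate in one shot, then drop the prompt tokens before decoding.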
        output_ids = model.generate(**generate_kwargs)
        output = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):], skip_special_tokens=True)
        response[-1]["content"] = output
        return response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        return response
    finally:
        flush()
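
# Assemble the Gradio UI: a ChatInterface wired to the streaming handler, with
# the sampling parameters exposed in a collapsible accordion.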
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_stream,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False),
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.'],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)
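
# To run this app locally (a sketch; assumes the Space's usual dependencies
# rather than a pinned requirements.txt):
#   pip install spaces gradio transformers accelerate bitsandbytes torch
#   python app.py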