# chatbot-zero / app.py
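# Gradio chat Space: loads the model configured below (4-bit NF4 quantization on GPU,
# float32 on CPU) and streams replies to a ChatInterface via TextIteratorStreamer.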
import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
import gc

def flush():
    """Free Python and CUDA memory between generation calls."""
    gc.collect()
    torch.cuda.empty_cache()


torch.set_float32_matmul_precision("high")
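
# Optional Hugging Face access token read from the environment, and the model
# repository to serve (the commented lines are alternative checkpoints).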
HF_TOKEN = os.getenv("HF_TOKEN", None)
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
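
# HTML header shown above the chat and the placeholder displayed in an empty chat window.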
DESCRIPTION = f'''
<div>
<h1 style="text-align: center;">{REPO_ID}</h1>
</div>
'''
PLACEHOLDER = f"""
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
if torch.cuda.is_available():
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                    bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, device_map="auto", quantization_config=nf4_config)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)
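
# The streamer turns tokens produced by model.generate() into an iterator of text
# chunks; this single instance is shared by every request in the demo.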
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# ZeroGPU: each call may hold a GPU for at most 59 seconds.
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat_stream(message: str,
                history: list[dict],
                temperature: float,
                max_new_tokens: int,
                top_p: float,
                top_k: int,
                repetition_penalty: float,
                sys_prompt: str,
                progress=gr.Progress(track_tqdm=True)
                ):
    """Generate a reply token by token, yielding the partial assistant message as it grows."""
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})
        # Build the prompt from the stored history plus the new system/user turns.
        input_tensors = tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()],
                                                      add_generation_prompt=True, return_dict=True, add_special_tokens=False, return_tensors="pt").to(model.device)
        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]
        #print("history: ", [{"role": x["role"], "content": x["content"]} for x in history if "role" in x.keys()])
        #print("messages: ", [{"role": x["role"], "content": x["content"]} for x in messages if "role" in x.keys()])
        #print("tokenized: ", tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()], add_generation_prompt=True, add_special_tokens=False, tokenize=False))
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False  # fall back to greedy decoding
        response.append({"role": "assistant", "content": ""})
        # Run generation in a background thread so the streamer can be consumed here.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        for text in streamer:
            response[-1]["content"] += text
            yield response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
    finally:
        flush()
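
# Non-streaming variant of the handler above: returns the finished reply in one call.
# It is not referenced by the UI below.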
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Generate a reply in one pass and return the finished assistant message."""
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})
        input_tensors = tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()],
                                                      add_generation_prompt=True, return_dict=True, add_special_tokens=False, return_tensors="pt").to(model.device)
        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False  # fall back to greedy decoding
        response.append({"role": "assistant", "content": ""})
        output_ids = model.generate(**generate_kwargs)
        # Decode only the newly generated tokens, skipping the prompt.
        output = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):], skip_special_tokens=True)
        response[-1]["content"] = output
        return response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        return response
    finally:
        flush()
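
# Assemble the UI: a ChatInterface backed by chat_stream, with the sampling parameters
# exposed in a collapsible "Parameters" accordion.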
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_stream,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False),
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.'],
        ],
        cache_examples=False,
    )
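
# queue() enables request queuing; ssr_mode=False turns off Gradio's server-side rendering.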
if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)