Spaces:
Runtime error
Runtime error
File size: 7,330 Bytes
04c9edc 43e2aa6 04c9edc f53e84c 04c9edc 43e2aa6 04c9edc 8c30c71 04c9edc 43e2aa6 f53e84c 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 5d07c9b 77c7cf9 8c30c71 f53e84c 04c9edc 8c30c71 04c9edc f53e84c 04c9edc 8c30c71 04c9edc 43e2aa6 04c9edc 77c7cf9 8c30c71 04c9edc 77c7cf9 8c30c71 43e2aa6 04c9edc 77c7cf9 04c9edc 77c7cf9 43e2aa6 04c9edc f53e84c 04c9edc f53e84c 77c7cf9 04c9edc 77c7cf9 43e2aa6 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
import gc
def flush() -> None:
    """Run a garbage-collection pass, then release cached CUDA memory.

    Called after every chat request (from the ``finally`` clauses of the
    chat handlers) so memory from a finished generation is handed back.
    """
    # Collect Python garbage first so tensors that are no longer referenced
    # are freed before the CUDA caching allocator is asked to release blocks.
    gc.collect()
    torch.cuda.empty_cache()
# Select the "high" float32 matmul precision mode for torch.
torch.set_float32_matmul_precision("high")

# Optional Hugging Face access token read from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Checkpoint served by this app; the commented lines are alternative choices.
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Page header rendered through gr.Markdown: the model name as a heading.
DESCRIPTION = f"""
<div>
<h1 style="text-align: center;">{REPO_ID}</h1>
</div>
"""

# HTML shown inside the chatbot while the conversation is still empty.
PLACEHOLDER = f'''
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
'''

# Custom stylesheet handed to gr.Blocks.
css = '''
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
'''
# The tokenizer, model and streamer are created once at import time.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

if not torch.cuda.is_available():
    # No CUDA device: load the unquantized weights in float32.
    model = AutoModelForCausalLM.from_pretrained(
        REPO_ID,
        torch_dtype=torch.float32,
    )
else:
    # CUDA path: 4-bit NF4 quantization with double quantization and
    # bfloat16 compute, placed automatically across available devices.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        REPO_ID,
        quantization_config=nf4_config,
        device_map="auto",
    )

# Module-level text streamer: skips the prompt and special tokens and waits
# at most 10 seconds for each new chunk of text.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    timeout=10.0,
)
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat_stream(message: str,
                history: list[dict],
                temperature: float,
                max_new_tokens: int,
                top_p: float,
                top_k: int,
                repetition_penalty: float,
                sys_prompt: str,
                progress=gr.Progress(track_tqdm=True)
                ):
    """Stream the assistant's reply to ``message`` chunk by chunk.

    Args:
        message: The new user message.
        history: Earlier turns as dicts with "role" and "content" keys;
            entries without a "role" key are ignored. May be empty or None.
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        sys_prompt: Optional system prompt; skipped when blank.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Yields:
        A list holding a single assistant message dict whose "content"
        grows as text arrives. On failure a Gradio warning is raised and the
        text collected so far (possibly an empty list) is yielded once more.
    """
    response = []
    try:
        # The system prompt goes first (chat templates conventionally expect
        # it there); an empty one is not sent at all.
        conversation = []
        if sys_prompt and sys_prompt.strip():
            conversation.append({"role": "system", "content": sys_prompt})
        conversation.extend(
            {"role": x["role"], "content": x["content"]}
            for x in (history or [])
            if "role" in x
        )
        conversation.append({"role": "user", "content": message})

        input_tensors = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_dict=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)

        # One streamer per request: a shared module-level streamer would mix
        # text between concurrent requests and could replay chunks left over
        # from a request that failed part-way.
        token_stream = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = dict(
            input_ids=input_tensors["input_ids"],
            attention_mask=input_tensors["attention_mask"],
            # Slider values may arrive as int or float; normalize the types.
            max_new_tokens=int(max_new_tokens),
            streamer=token_stream,
            do_sample=False,
            repetition_penalty=float(repetition_penalty),
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature != 0:
            # Sampling parameters are only passed when sampling is enabled;
            # with greedy decoding they would be unused.
            generate_kwargs.update(
                do_sample=True,
                temperature=float(temperature),
                top_k=int(top_k),
                top_p=float(top_p),
            )

        worker_errors = []

        def _generate():
            # Runs in the worker thread. On failure, record the error and
            # close the stream so the consumer loop stops immediately rather
            # than waiting for the streamer timeout.
            try:
                model.generate(**generate_kwargs)
            except Exception as err:
                worker_errors.append(err)
                token_stream.end()

        response.append({"role": "assistant", "content": ""})
        t = Thread(target=_generate)
        t.start()
        for text in token_stream:
            response[-1]["content"] += text
            yield response
        t.join()
        if worker_errors:
            # Re-raise in this thread so the handler below reports it.
            raise worker_errors[0]
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
    finally:
        flush()
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Generate the assistant's full reply to ``message`` in one call.

    Args:
        message: The new user message.
        history: Earlier turns as dicts with "role" and "content" keys;
            entries without a "role" key are ignored. May be empty or None.
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        sys_prompt: Optional system prompt; skipped when blank.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        A list holding a single assistant message dict. On failure a Gradio
        warning is raised and whatever was built so far is returned (an empty
        list, or an assistant message with empty content).
    """
    response = []
    try:
        # The system prompt goes first (chat templates conventionally expect
        # it there); an empty one is not sent at all.
        conversation = []
        if sys_prompt and sys_prompt.strip():
            conversation.append({"role": "system", "content": sys_prompt})
        conversation.extend(
            {"role": x["role"], "content": x["content"]}
            for x in (history or [])
            if "role" in x
        )
        conversation.append({"role": "user", "content": message})

        input_tensors = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_dict=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        input_ids = input_tensors["input_ids"]

        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=input_tensors["attention_mask"],
            # Slider values may arrive as int or float; normalize the types.
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            repetition_penalty=float(repetition_penalty),
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature != 0:
            # Sampling parameters are only passed when sampling is enabled;
            # with greedy decoding they would be unused.
            generate_kwargs.update(
                do_sample=True,
                temperature=float(temperature),
                top_k=int(top_k),
                top_p=float(top_p),
            )

        response.append({"role": "assistant", "content": ""})
        output_ids = model.generate(**generate_kwargs)
        # Decode only the newly generated tokens of the first sequence,
        # i.e. everything after the prompt.
        new_tokens = output_ids[0][input_ids.size(1):]
        response[-1]["content"] = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        return response
    finally:
        flush()
# Build the page: a header followed by a chat interface wired to chat_stream.
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_stream,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order matters: these values are passed to chat_stream after
        # (message, history) as temperature, max_new_tokens, top_p, top_k,
        # repetition_penalty and sys_prompt.
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            # The repetition penalty must be strictly positive; a minimum of
            # 0.0 let users pick a value that generation rejects.
            gr.Slider(minimum=0.1, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False)
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False)

if __name__ == "__main__":
    # Enable request queuing and start the server with SSR mode turned off.
    demo.queue().launch(ssr_mode=False)
|