import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
import gc

def flush():
    # Drop unreferenced Python objects and return cached CUDA memory to the driver;
    # called in each handler's finally block so a failed request does not keep VRAM pinned.
    gc.collect()
    torch.cuda.empty_cache()

torch.set_float32_matmul_precision("high")  # allow TF32 matmul kernels on Ampere+ GPUs

HF_TOKEN = os.getenv("HF_TOKEN", None)  # only needed for gated or private repos
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

DESCRIPTION = f'''
<div>
<h1 style="text-align: center;">{REPO_ID}</h1>
</div>
'''

PLACEHOLDER = f"""

<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">

   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>

   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>

</div>

"""

css = """

h1 {

  text-align: center;

  display: block;

}



#duplicate-button {

  margin: auto;

  color: white;

  background: #1565c0;

  border-radius: 100vh;

}

"""

tokenizer = AutoTokenizer.from_pretrained(REPO_ID, token=HF_TOKEN)
if torch.cuda.is_available():
    # Quantize the checkpoint to 4-bit NF4 on load so the 32B model fits in GPU memory.
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, device_map="auto", quantization_config=nf4_config, token=HF_TOKEN)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32, token=HF_TOKEN)
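
# chat_stream() below runs model.generate() on a worker thread and streams partial text back
# to the UI through a TextIteratorStreamer. @spaces.GPU asks ZeroGPU to attach a GPU to the
# call, capped here at 59 seconds per request.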

@spaces.GPU(duration=59)
@torch.inference_mode()
def chat_stream(message: str,
                history: list[dict],
                temperature: float,
                max_new_tokens: int,
                top_p: float,
                top_k: int,
                repetition_penalty: float,
                sys_prompt: str,
                progress=gr.Progress(track_tqdm=True)):
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})

        input_tensors = tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()], add_generation_prompt=True, return_dict=True, add_special_tokens=False, return_tensors="pt").to(model.device)

        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]

        #print("history: ", [{"role": x["role"], "content": x["content"]} for x in history if "role" in x.keys()])
        #print("messages: ", [{"role": x["role"], "content": x["content"]} for x in messages if "role" in x.keys()])
        #print("tokenized: ", tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()], add_generation_prompt=True, add_special_tokens=False, tokenize=False))

        # A fresh streamer per request avoids leftover tokens from a previous, partially consumed
        # generation leaking into this one; the generous timeout covers slow first tokens.
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False  # temperature 0 falls back to greedy decoding
        response.append({"role": "assistant", "content": ""})

        # Run generate() on a worker thread so decoded text can be yielded to Gradio as it streams in.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        for text in streamer:
            response[-1]["content"] += text
            yield response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
    finally:
        flush()
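
# chat() is a non-streaming variant with the same signature; it is not wired into the UI below,
# but it can be swapped in for fn=chat_stream if streaming output is not wanted.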

@spaces.GPU(duration=59)
@torch.inference_mode()
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)):
    try:
        messages = []
        response = []
        if not history: history = []
        messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": message})

        input_tensors = tokenizer.apply_chat_template([{"role": x["role"], "content": x["content"]} for x in history + messages if "role" in x.keys()], add_generation_prompt=True, return_dict=True, add_special_tokens=False, return_tensors="pt").to(model.device)

        input_ids = input_tensors["input_ids"]
        attention_mask = input_tensors["attention_mask"]

        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0: generate_kwargs['do_sample'] = False
        response.append({"role": "assistant", "content": ""})

        output_ids = model.generate(**generate_kwargs)
        output = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1) :], skip_special_tokens=True)

        response[-1]["content"] = output
        return response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        return response
    finally:
        flush()
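
# Assemble the UI: a ChatInterface in "messages" mode wired to the streaming handler, with the
# sampling controls exposed as additional inputs inside a collapsible accordion.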

with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_stream,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
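        # These additional inputs are passed to chat_stream positionally, in the same order as
        # its parameters after (message, history).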
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False)
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False)

if __name__ == "__main__":
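    # queue() enables request queuing (one generation at a time by default);
    # ssr_mode=False disables Gradio's server-side rendering.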
    demo.queue().launch(ssr_mode=False)