DFofanov78 committed on
Commit 12420cf
1 Parent(s): a40e9ff

Update app.py

Files changed (1):
  1. app.py +186 -50
app.py CHANGED
@@ -1,63 +1,199 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
         top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
 if __name__ == "__main__":
-    demo.launch()
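For reference, the hosted-inference pattern being removed here can be exercised outside Gradio. A minimal sketch (not part of the commit), assuming `huggingface_hub` v0.22+ and that the public Inference API still serves `HuggingFaceH4/zephyr-7b-beta`:

    from huggingface_hub import InferenceClient

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
    messages = [
        {"role": "system", "content": "You are a friendly Chatbot."},
        {"role": "user", "content": "Hello!"},
    ]

    # Stream the reply; each chunk carries an incremental text delta.
    for chunk in client.chat_completion(messages, max_tokens=64, stream=True):
        delta = chunk.choices[0].delta.content
        if delta:  # the final chunk may carry no content
            print(delta, end="", flush=True)

The replacement below drops this hosted client entirely: it downloads a GGUF checkpoint once at startup and runs it locally through llama_cpp.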
 import gradio as gr
+import os
+
+from huggingface_hub.file_download import http_get
+from llama_cpp import Llama
+
+SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+
+def get_message_tokens(model, role, content):
+    content = f"{role}\n{content}\n</s>"
+    content = content.encode("utf-8")
+    return model.tokenize(content, special=True)
+
+
+def get_system_tokens(model):
+    system_message = {"role": "system", "content": SYSTEM_PROMPT}
+    return get_message_tokens(model, **system_message)
+
+
+def load_model(
+    directory: str = ".",
+    model_name: str = "RKF-v1-8b-Instruct-q4_k_m-gguf-unsloth.Q4_K_M.gguf",
+    model_url: str = "https://huggingface.co/DFofanov78/RKF-v1-8b-Instruct-q4_k_m-gguf/resolve/main/RKF-v1-8b-Instruct-q4_k_m-gguf-unsloth.Q4_K_M.gguf"
 ):
+    final_model_path = os.path.join(directory, model_name)
+
+    print("Downloading all files...")
+    if not os.path.exists(final_model_path):
+        with open(final_model_path, "wb") as f:
+            http_get(model_url, f)
+        os.chmod(final_model_path, 0o777)
+    print("Files downloaded!")
+
+    model = Llama(
+        model_path=final_model_path,
+        n_ctx=1024
+    )
+
+    print("Model loaded!")
+    return model
+
+
+MODEL = load_model()
+
+def user(message, history):
+    new_history = history + [[message, None]]
+    return "", new_history
+
+
+def bot(
+    history,
+    system_prompt,
+    top_p,
+    top_k,
+    temp
+):
+    model = MODEL
+    tokens = get_system_tokens(model)[:]
+
+    for user_message, bot_message in history[:-1]:
+        message_tokens = get_message_tokens(model=model, role="user", content=user_message)
+        tokens.extend(message_tokens)
+        if bot_message:
+            message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
+            tokens.extend(message_tokens)
+
+    last_user_message = history[-1][0]
+    message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
+    tokens.extend(message_tokens)
+
+    role_tokens = model.tokenize("bot\n".encode("utf-8"), special=True)
+    tokens.extend(role_tokens)
+    generator = model.generate(
+        tokens,
+        top_k=top_k,
         top_p=top_p,
+        temp=temp
+    )
+
+    partial_text = ""
+    for i, token in enumerate(generator):
+        if token == model.token_eos():
+            break
+        partial_text += model.detokenize([token]).decode("utf-8", "ignore")
+        history[-1][1] = partial_text
+        yield history
+
+
+with gr.Blocks(
+    theme=gr.themes.Soft()
+) as demo:
+    favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
+    gr.Markdown(
+        f"""<h1><center>{favicon}Saiga2 13B GGUF Q4_K</center></h1>
+        This is a demo of a **Russian**-speaking LLaMA2-based model. If you are interested in other languages, please check other models, such as [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat).
+        Это демонстрационная версия [квантованной Сайги-2 с 13 миллиардами параметров](https://huggingface.co/IlyaGusev/saiga2_13b_ggml), работающая на CPU.
+        Сайга-2 — это разговорная языковая модель, которая основана на [LLaMA-2](https://ai.meta.com/llama/) и дообучена на корпусах, сгенерированных ChatGPT, таких как [ru_turbo_alpaca](https://huggingface.co/datasets/IlyaGusev/ru_turbo_alpaca), [ru_turbo_saiga](https://huggingface.co/datasets/IlyaGusev/ru_turbo_saiga) и [gpt_roleplay_realm](https://huggingface.co/datasets/IlyaGusev/gpt_roleplay_realm).
+        """
+    )
+    with gr.Row():
+        with gr.Column(scale=5):
+            system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
+            chatbot = gr.Chatbot(label="Диалог")
+        with gr.Column(min_width=80, scale=1):
+            with gr.Tab(label="Параметры генерации"):
+                top_p = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.05,
+                    interactive=True,
+                    label="Top-p",
+                )
+                top_k = gr.Slider(
+                    minimum=10,
+                    maximum=100,
+                    value=30,
+                    step=5,
+                    interactive=True,
+                    label="Top-k",
+                )
+                temp = gr.Slider(
+                    minimum=0.0,
+                    maximum=2.0,
+                    value=0.01,
+                    step=0.01,
+                    interactive=True,
+                    label="Температура"
+                )
+    with gr.Row():
+        with gr.Column():
+            msg = gr.Textbox(
+                label="Отправить сообщение",
+                placeholder="Отправить сообщение",
+                show_label=False,
+            )
+        with gr.Column():
+            with gr.Row():
+                submit = gr.Button("Отправить")
+                stop = gr.Button("Остановить")
+                clear = gr.Button("Очистить")
+    with gr.Row():
+        gr.Markdown(
+            """ПРЕДУПРЕЖДЕНИЕ: Модель может генерировать фактически или этически некорректные тексты. Мы не несём за это ответственность."""
+        )
+
+    # Pressing Enter
+    submit_event = msg.submit(
+        fn=user,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).success(
+        fn=bot,
+        inputs=[
+            chatbot,
+            system_prompt,
+            top_p,
+            top_k,
+            temp
+        ],
+        outputs=chatbot,
+        queue=True,
+    )
+
+    # Pressing the button
+    submit_click_event = submit.click(
+        fn=user,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).success(
+        fn=bot,
+        inputs=[
+            chatbot,
+            system_prompt,
+            top_p,
+            top_k,
+            temp
+        ],
+        outputs=chatbot,
+        queue=True,
+    )
+
+    # Stop generation
+    stop.click(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        cancels=[submit_event, submit_click_event],
+        queue=False,
+    )
+
+    # Clear history
+    clear.click(lambda: None, None, chatbot, queue=False)
+
 if __name__ == "__main__":
+    demo.queue(max_size=128)
+    demo.launch(show_error=True)
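The token assembly in `bot` implements a simple Saiga-style chat template. A sketch of the equivalent prompt as a plain string (`render_prompt` is a hypothetical helper for illustration; the committed code tokenizes each turn separately with `special=True` rather than concatenating one string):

    SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

    def render_prompt(history):
        # Mirrors get_message_tokens: every turn is "{role}\n{content}\n</s>",
        # and generation starts after a dangling "bot\n" role header.
        parts = [f"system\n{SYSTEM_PROMPT}\n</s>"]
        for user_message, bot_message in history:
            parts.append(f"user\n{user_message}\n</s>")
            if bot_message:
                parts.append(f"bot\n{bot_message}\n</s>")
        parts.append("bot\n")
        return "".join(parts)

    # The first turn of a fresh chat, as the model sees it:
    print(render_prompt([["Привет!", None]]))

One design note: the model is loaded with n_ctx=1024 and `bot` does no truncation, so long chat histories will eventually overflow the context window; anyone reusing this pattern should trim older turns before tokenizing.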