IlyaGusev committed
Commit 4a27403
1 Parent(s): cf71dc4

Create new file

Files changed (1)
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
+ import uuid
+
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ from llama_cpp import Llama
+
+
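+ # The system prompt translates to: "You are Saiga, a Russian-language automatic assistant. You talk to people and help them."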
+ SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
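+
+ # Hard-coded ids of the role-marker tokens (and of "\n") in the model's vocabulary.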
+ SYSTEM_TOKEN = 1788
+ USER_TOKEN = 1404
+ BOT_TOKEN = 9225
+ LINEBREAK_TOKEN = 13
+ ROLE_TOKENS = {
+     "system": SYSTEM_TOKEN,
+     "user": USER_TOKEN,
+     "bot": BOT_TOKEN,
+ }
+
+
+ def get_uuid():
+     return str(uuid.uuid4())
+
+
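+ # One dialogue turn is encoded as: <s> <role_token> "\n" <content> </s>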
+ def get_message_tokens(model, role, content):
+     message_tokens = model.tokenize(content.encode("utf-8"))
+     message_tokens.insert(1, ROLE_TOKENS[role])
+     message_tokens.insert(2, LINEBREAK_TOKEN)
+     message_tokens.append(model.token_eos())
+     return message_tokens
+
+
+ def get_system_tokens(model):
+     system_message = {"role": "system", "content": SYSTEM_PROMPT}
+     return get_message_tokens(model, **system_message)
+
+
+ repo_name = "IlyaGusev/saiga2_13b_ggml"
+ model_name = "ggml-model-q4_1.bin"
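+
+ # Fetch only the quantized GGML weights file from the Hub into the working directory.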
+ snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)
+
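+ # Load the model with a 2000-token context window.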
+ model = Llama(
+     model_path=model_name,
+     n_ctx=2000,
+     n_parts=1,
+ )
+
+ max_new_tokens = 1500
+
+
+ def user(message, history):
+     # Append the new user message to the history and clear the textbox.
+     new_history = history + [[message, None]]
+     return "", new_history
+
+
+ def bot(
+     history,
+     system_prompt,
+     top_p,
+     top_k,
+     temp
+ ):
+     tokens = get_system_tokens(model)[:]
+     tokens.append(LINEBREAK_TOKEN)
+
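+     # Replay earlier turns as role-tagged token sequences.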
+     for user_message, bot_message in history[:-1]:
+         message_tokens = get_message_tokens(model=model, role="user", content=user_message)
+         tokens.extend(message_tokens)
+         if bot_message:
+             message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
+             tokens.extend(message_tokens)
+
+     last_user_message = history[-1][0]
+     message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
+     tokens.extend(message_tokens)
+
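+     # Open the bot's turn and stream tokens until EOS or max_new_tokens.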
+     role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
+     tokens.extend(role_tokens)
+     generator = model.generate(
+         tokens,
+         top_k=top_k,
+         top_p=top_p,
+         temp=temp
+     )
+
+     partial_text = ""
+     for i, token in enumerate(generator):
+         if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
+             break
+         partial_text += model.detokenize([token]).decode("utf-8", "ignore")
+         history[-1][1] = partial_text
+         yield history
+
+
+ with gr.Blocks(
+     theme=gr.themes.Soft()
+ ) as demo:
+     conversation_id = gr.State(get_uuid)
+     favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
+     gr.Markdown(
+         f"""<h1><center>{favicon}Saiga2 13B</center></h1>
+
+ This is a demo of a **Russian**-speaking LLaMA2-based model. If you are interested in other languages, please check other models, such as [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat).
+
+ Это демонстрационная версия [Сайги-2 с 13 миллиардами параметров](https://huggingface.co/IlyaGusev/saiga_13b_lora).
+
+ Сайга — это разговорная языковая модель, которая основана на [LLaMA](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) и дообучена на корпусах, сгенерированных ChatGPT, таких как [ru_turbo_alpaca](https://huggingface.co/datasets/IlyaGusev/ru_turbo_alpaca), [ru_turbo_saiga](https://huggingface.co/datasets/IlyaGusev/ru_turbo_saiga) и [gpt_roleplay_realm](https://huggingface.co/datasets/IlyaGusev/gpt_roleplay_realm).
+ """
+     )
+     with gr.Row():
+         with gr.Column(scale=5):
+             system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
+             chatbot = gr.Chatbot(label="Диалог").style(height=400)
+         with gr.Column(min_width=80, scale=1):
+             with gr.Tab(label="Параметры генерации"):
+                 top_p = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05,
+                     interactive=True,
+                     label="Top-p",
+                 )
+                 top_k = gr.Slider(
+                     minimum=10,
+                     maximum=100,
+                     value=30,
+                     step=5,
+                     interactive=True,
+                     label="Top-k",
+                 )
+                 temp = gr.Slider(
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=0.1,
+                     step=0.1,
+                     interactive=True,
+                     label="Temp"
+                 )
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Отправить сообщение",
+                 placeholder="Отправить сообщение",
+                 show_label=False,
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Отправить")
+                 stop = gr.Button("Остановить")
+                 clear = gr.Button("Очистить")
+     with gr.Row():
+         gr.Markdown(
+             """ПРЕДУПРЕЖДЕНИЕ: Модель может генерировать фактически или этически некорректные тексты. Мы не несём за это ответственность."""
+         )
+
+     # Pressing Enter
+     submit_event = msg.submit(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Pressing the button
+     submit_click_event = submit.click(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Stop generation
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+
+     # Clear history
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.queue(max_size=128, concurrency_count=1)
+ demo.launch()