muryshev committed on
Commit
377a7af
1 Parent(s): ff6ee2c

Create app.py

Files changed (1)
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
from flask import Flask, request, Response
import logging
from llama_cpp import Llama
import threading
from huggingface_hub import snapshot_download

# System prompt for the model (Russian: "You are a Russian-language automatic
# assistant. You answer user queries as accurately as possible.")
SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя."

# Special token ids used by the Saiga-2 prompt format
SYSTEM_TOKEN = 1788
USER_TOKEN = 1404
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13

ROLE_TOKENS = {
    "user": USER_TOKEN,
    "bot": BOT_TOKEN,
    "system": SYSTEM_TOKEN
}

# Create a lock object (currently unused: the "with lock:" below is commented
# out, so concurrent generation requests are not serialized)
lock = threading.Lock()

app = Flask(__name__)
# Configure Flask logging
app.logger.setLevel(logging.DEBUG)  # Set the desired logging level

# Initialize the model when the application starts
#model_path = "../models/model-q4_K.gguf"  # Replace with the actual model path
#model_name = "model/ggml-model-q4_K.gguf"

repo_name = "IlyaGusev/saiga2_13b_gguf"
model_name = "model-q4_K.gguf"

# Download only the quantized model file into the working directory
snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)

model = Llama(
    model_path=model_name,
    n_ctx=2000,
    n_parts=1,
    #n_batch=100,
    logits_all=True,
    #n_threads=12,
    verbose=True
)
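# Note: logits_all=True keeps logits for every prompt token; plain generation
# does not need it and it increases memory use, so it could likely be dropped.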


def get_message_tokens(model, role, content):
    message_tokens = model.tokenize(content.encode("utf-8"))
    message_tokens.insert(1, ROLE_TOKENS[role])
    message_tokens.insert(2, LINEBREAK_TOKEN)
    message_tokens.append(model.token_eos())
    return message_tokens
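
# For role="user" this yields [BOS, USER_TOKEN, LINEBREAK_TOKEN,
# <content tokens>, EOS], i.e. roughly one "<s>user\n{content}</s>" turn of
# the Saiga prompt format (model.tokenize() prepends BOS by default).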

def get_system_tokens(model):
    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT
    }
    return get_message_tokens(model, **system_message)

def get_system_tokens_for_preprompt(model, preprompt):
    system_message = {
        "role": "system",
        "content": preprompt
    }
    return get_message_tokens(model, **system_message)

# Warm-up evaluation of the system prompt is currently disabled
#system_tokens = get_system_tokens(model)
#model.eval(system_tokens)

stop_generation = False

def generate_tokens(model, generator):
    global stop_generation
    app.logger.info('generate_tokens started')
    #with lock:
    for token in generator:
        if token == model.token_eos() or stop_generation:
            stop_generation = False
            yield b''  # End of chunk
            break

        token_str = model.detokenize([token])  #.decode("utf-8", errors="ignore")
        yield token_str
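
# Note: model.detokenize() returns raw UTF-8 bytes, so the HTTP stream is
# binary and a multi-byte character may be split across chunks; clients should
# decode the concatenated stream rather than each chunk separately.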

@app.route('/stop_generation', methods=['GET'])
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
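
# Example (assuming the default port configured below):
#   curl http://localhost:7860/stop_generation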

@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        request_payload = request.get_json()
        app.logger.info('payload: %s', request_payload)
    except Exception:
        app.logger.info('payload empty')

    return Response('What do you want?', content_type='text/plain')

@app.route('/search_request', methods=['POST'])
def generate_search_request():
    global stop_generation
    stop_generation = False
    data = request.get_json()
    app.logger.info(data)
    user_query = data.get("query", "")
    # Default preprompt (Russian: "You are a Russian-language automatic
    # assistant for writing search-engine queries. Reply to the user's messages
    # with only the text of a search query relevant to the user's request.
    # If the user's query is already good, use it as the result.")
    preprompt = data.get("preprompt", "Ты — русскоязычный автоматический ассистент для написания запросов для поисковых систем. Отвечай на сообщения пользователя только текстом поискового запроса, релевантным запросу пользователя. Если запрос пользователя уже хорош, используй его в качестве результата.")
    parameters = data.get("parameters", {})

    # Extract parameters from the request (truncate, max_new_tokens and
    # return_full_text are accepted but not used yet)
    temperature = 0.01
    truncate = parameters.get("truncate", 1000)
    max_new_tokens = parameters.get("max_new_tokens", 1024)
    top_p = 0.8
    repetition_penalty = parameters.get("repetition_penalty", 1.2)
    top_k = 20
    return_full_text = parameters.get("return_full_text", False)

    # Build the prompt: the preprompt as the system turn, then the (truncated)
    # user query, then the opening of the bot turn
    tokens = get_system_tokens_for_preprompt(model, preprompt)
    tokens.append(LINEBREAK_TOKEN)

    tokens += get_message_tokens(model=model, role="user", content=user_query[:200]) + [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]

    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temperature,
        repeat_penalty=repetition_penalty
    )

    # Use Response to stream tokens
    return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
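
# Example request (hypothetical values):
#   curl -X POST http://localhost:7860/search_request \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "погода в Москве завтра"}'
# The generated search query is streamed back as plain text.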

@app.route('/', methods=['POST'])
def generate_response():
    global stop_generation
    stop_generation = False

    data = request.get_json()
    app.logger.info(data)
    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract parameters from the request (sampling values are currently
    # pinned; truncate, max_new_tokens and return_full_text are unused)
    temperature = 0.02  #parameters.get("temperature", 0.01)
    truncate = parameters.get("truncate", 1000)
    max_new_tokens = parameters.get("max_new_tokens", 1024)
    top_p = 0.8  #parameters.get("top_p", 0.85)
    repetition_penalty = parameters.get("repetition_penalty", 1.2)
    top_k = 25  #parameters.get("top_k", 30)
    return_full_text = parameters.get("return_full_text", False)

    # Build the prompt: the default system turn first, then the conversation history
    #system_tokens = get_system_tokens(model)
    #tokens = system_tokens

    #if preprompt != "":
    #    tokens = get_system_tokens_for_preprompt(model, preprompt)
    #else:
    tokens = get_system_tokens(model)
    tokens.append(LINEBREAK_TOKEN)
    #model.eval(tokens)

    # Append each conversation turn in the Saiga role format
    for message in messages:
        if message.get("from") == "assistant":
            message_tokens = get_message_tokens(model=model, role="bot", content=message.get("content", ""))
        else:
            message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))

        tokens.extend(message_tokens)

    #app.logger.info('model.eval start')
    #model.eval(tokens)
    #app.logger.info('model.eval end')

    # Open the bot turn; the model generates the assistant reply from here
    tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])

    app.logger.info('Prompt:')
    app.logger.info(model.detokenize(tokens).decode("utf-8", errors="ignore"))

    app.logger.info('Generate started')
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temperature,
        repeat_penalty=repetition_penalty
    )
    app.logger.info('Generator created')

    # Use Response to stream tokens
    return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
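
# Example request (hypothetical values):
#   curl -X POST http://localhost:7860/ \
#        -H 'Content-Type: application/json' \
#        -d '{"messages": [{"from": "user", "content": "Привет!"}]}'
# The assistant reply is streamed token by token as plain text.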

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)  #, threaded=False)