muryshev committed on
Commit
757dd81
1 Parent(s): 7b04dc0

Update app.py

Files changed (1)
  1. app.py +6 -90
app.py CHANGED
@@ -39,8 +39,8 @@ app.logger.setLevel(logging.DEBUG) # Set the desired logging level
 #repo_name = "IlyaGusev/saiga2_13b_gguf"
 #model_name = "model-q4_K.gguf"
 
-repo_name = "IlyaGusev/saiga2_70b_gguf"
-model_name = "ggml-model-q4_1.gguf"
+repo_name = "dreamgen/opus-v0-70b-gguf"
+model_name = "dreamgen-opus-v0-70b-Q4_K_M.gguf"
 
 #repo_name = "IlyaGusev/saiga2_7b_gguf"
 #model_name = "model-q4_K.gguf"
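
The model swap above takes effect through the snapshot_download call visible in the next hunk's header. A minimal sketch of that resolution step, assuming huggingface_hub is installed; the exact path join in app.py is truncated in the header below, so the join here is an assumption:

from huggingface_hub import snapshot_download

repo_name = "dreamgen/opus-v0-70b-gguf"
model_name = "dreamgen-opus-v0-70b-Q4_K_M.gguf"

# Fetch only the single GGUF file from the repo; snapshot_download returns
# the local snapshot directory, to which app.py appends something to build model_path.
snapshot_dir = snapshot_download(repo_id=repo_name, allow_patterns=model_name)
model_path = snapshot_dir + "/" + model_name  # assumed join; the real one is cut off in the diff
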
@@ -55,7 +55,7 @@ model_path = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '
 app.logger.info('Model path: ' + model_path)
 
 DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
-DATA_FILENAME = "data-saiga-cuda.xml"
+DATA_FILENAME = "opus-v0-70b.xml"
 DATA_FILE = os.path.join("dataset", DATA_FILENAME)
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
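
The code that writes to DATA_FILE sits outside this diff, but the constants suggest the usual Spaces persistence pattern: clone the dataset repo locally, append records, push with HF_TOKEN. A hedged sketch of that pattern; the Repository usage and record format are assumptions, not taken from app.py:

from huggingface_hub import Repository

# Clone https://huggingface.co/datasets/muryshev/saiga-chat into ./dataset
repo = Repository(local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)

with open(DATA_FILE, "a", encoding="utf-8") as f:
    f.write("<chat>...</chat>\n")  # hypothetical record format
repo.push_to_hub()
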
@@ -190,46 +190,6 @@ def generate_unknown_response():
     app.logger.info('payload empty')
 
     return Response('What do you want?', content_type='text/plain')
-
-@app.route('/search_request', methods=['POST'])
-def generate_search_request():
-    global stop_generation
-    stop_generation = True
-    model.reset()
-
-    data = request.get_json()
-    app.logger.info(data)
-    user_query = data.get("query", "")
-    preprompt = data.get("preprompt", "")
-    parameters = data.get("parameters", {})
-
-    # Extract parameters from the request
-    temperature = parameters.get("temperature", 0.01)
-    truncate = parameters.get("truncate", 1000)
-    max_new_tokens = parameters.get("max_new_tokens", 1024)
-    top_p = parameters.get("top_p", 0.85)
-    repetition_penalty = parameters.get("repetition_penalty", 1.2)
-    top_k = parameters.get("top_k", 30)
-    return_full_text = parameters.get("return_full_text", False)
-
-    tokens = get_system_tokens_for_preprompt(model, preprompt)
-    tokens.append(LINEBREAK_TOKEN)
-
-    tokens = get_message_tokens(model=model, role="user", content=user_query[:200]) + [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
-    stop_generation = False
-    generator = model.generate(
-        tokens,
-        top_k=top_k,
-        top_p=top_p,
-        temp=temperature,
-        repeat_penalty=repetition_penalty
-    )
-
-    # Use Response to stream tokens
-    return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
 
 response_tokens = bytearray()
 def generate_and_log_tokens(user_request, model, generator):
@@ -245,57 +205,13 @@ def generate_and_log_tokens(user_request, model, generator):
 @app.route('/', methods=['POST'])
 def generate_response():
     global stop_generation
-    stop_generation = True
-    model.reset()
-
-    data = request.get_json()
-    app.logger.info(data)
-    messages = data.get("messages", [])
-    preprompt = data.get("preprompt", "")
-    parameters = data.get("parameters", {})
-
-    # Extract parameters from the request
-    temperature = parameters.get("temperature", 0.01)
-    truncate = parameters.get("truncate", 1000)
-    max_new_tokens = parameters.get("max_new_tokens", 1024)
-    top_p = parameters.get("top_p", 0.85)
-    repetition_penalty = parameters.get("repetition_penalty", 1.2)
-    top_k = parameters.get("top_k", 30)
-    return_full_text = parameters.get("return_full_text", False)
-
-    tokens = []
-
-    for message in messages:
-        if message.get("from") == "assistant":
-            message_tokens = get_message_tokens(model=model, role="bot", content=message.get("content", ""))
-        elif message.get("from") == "system":
-            message_tokens = get_message_tokens(model=model, role="system", content=message.get("content", ""))
-        else:
-            message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
-
-        tokens.extend(message_tokens)
-
-    tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
-
-    app.logger.info('Prompt:')
-    user_request = model.detokenize(tokens[:CONTEXT_SIZE]).decode("utf-8", errors="ignore")
-    app.logger.info(user_request)
-
-    stop_generation = False
-    app.logger.info('Generate started')
+    raw_content = request.data
+    user_request = raw_content.decode("utf-8", errors="ignore")  # decoded here so generate_and_log_tokens below receives a defined value
+    tokens = model.tokenize(raw_content)
     generator = model.generate(
-        tokens[:CONTEXT_SIZE],
-        top_k=top_k,
-        top_p=top_p,
-        temp=temperature,
-        repeat_penalty=repetition_penalty
+        tokens[:CONTEXT_SIZE]
     )
     app.logger.info('Generator created')
 
     # Use Response to stream tokens
     return Response(generate_and_log_tokens(user_request, model, generator), content_type='text/plain', status=200, direct_passthrough=True)
 
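Net effect of this commit on the / endpoint: the JSON contract (messages, preprompt, sampling parameters) is gone, and the raw request body is tokenized directly as the prompt. A usage sketch, assuming the default local Spaces port; only the new-style call matches the committed code:

import requests

url = "http://localhost:7860/"  # assumed host/port

# Before this commit the endpoint expected a JSON chat payload:
# requests.post(url, json={"messages": [{"from": "user", "content": "Hi"}],
#                          "preprompt": "", "parameters": {"temperature": 0.01}})

# After this commit the raw prompt bytes are tokenized directly:
response = requests.post(url, data="Hi".encode("utf-8"), stream=True)
for chunk in response.iter_content(chunk_size=None):
    print(chunk.decode("utf-8", errors="ignore"), end="")

Note that the new path also drops the per-request sampling parameters, so generation now runs with llama-cpp-python's defaults.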
 