winglian committed on
Commit
877b0e4
1 Parent(s): c4c5fc5

duplicate code b/c gradio is really wonky

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -304,7 +304,7 @@ def open_user(message, nudge_msg, history):
304
  return "", nudge_msg, history
305
 
306
 
307
- def open_chat(model_name, history, system_msg, max_new_tokens, temperature, top_p, top_k, repetition_penalty, roleplay=False):
308
  history = history or []
309
 
310
  model = get_model_pipeline(model_name)
@@ -332,8 +332,32 @@ def open_chat(model_name, history, system_msg, max_new_tokens, temperature, top_
332
  sleep(0.01)
333
 
334
 
335
- def open_rp_chat(*args):
336
- return open_chat(*args, roleplay=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
 
339
  with gr.Blocks() as arena:
 
304
  return "", nudge_msg, history
305
 
306
 
307
+ def open_chat(model_name, history, system_msg, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
308
  history = history or []
309
 
310
  model = get_model_pipeline(model_name)
 
332
  sleep(0.01)
333
 
334
 
335
+ def open_rp_chat(model_name, history, system_msg, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
336
+ history = history or []
337
+
338
+ model = get_model_pipeline(f"{model_name}-roleplay")
339
+ config = model.get_generation_config()
340
+ config["max_new_tokens"] = max_new_tokens
341
+ config["temperature"] = temperature
342
+ config["temperature"] = temperature
343
+ config["top_p"] = top_p
344
+ config["top_k"] = top_k
345
+ config["repetition_penalty"] = repetition_penalty
346
+
347
+ messages = model.transform_prompt(system_msg, history)
348
+
349
+ # remove last space from assistant, some models output a ZWSP if you leave a space
350
+ messages = messages.rstrip()
351
+
352
+ model_res = model(messages, config=config) # type: Generator[List[Dict[str, str]], None, None]
353
+ for res in model_res:
354
+ tokens = re.findall(r'\s*\S+\s*', res[0]['generated_text'])
355
+ for s in tokens:
356
+ answer = s
357
+ history[-1][1] += answer
358
+ # stream the response
359
+ yield history, history, ""
360
+ sleep(0.01)
361
 
362
 
363
  with gr.Blocks() as arena: