BAAI
/

ldwang commited on
Commit
380ba89
1 Parent(s): a3be330

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +7 -1
predict.py CHANGED
@@ -333,6 +333,7 @@ def predict(model, text, tokenizer=None,
333
  sft=True, convo_template = "",
334
  device = "cuda",
335
  model_name="AquilaChat2-7B",
 
336
  **kwargs):
337
 
338
  vocab = tokenizer.get_vocab()
@@ -352,7 +353,7 @@ def predict(model, text, tokenizer=None,
352
  topk = 1
353
  temperature = 1.0
354
  if sft:
355
- tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
356
  tokens = torch.tensor(tokens)[None,].to(device)
357
  else :
358
  tokens = tokenizer.encode_plus(text)["input_ids"]
@@ -433,4 +434,9 @@ def predict(model, text, tokenizer=None,
433
 
434
  convert_tokens = convert_tokens[1:]
435
  probs = probs[1:]
 
 
 
 
 
436
  return out
 
333
  sft=True, convo_template = "",
334
  device = "cuda",
335
  model_name="AquilaChat2-7B",
336
+ history=[],
337
  **kwargs):
338
 
339
  vocab = tokenizer.get_vocab()
 
353
  topk = 1
354
  temperature = 1.0
355
  if sft:
356
+ tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
357
  tokens = torch.tensor(tokens)[None,].to(device)
358
  else :
359
  tokens = tokenizer.encode_plus(text)["input_ids"]
 
434
 
435
  convert_tokens = convert_tokens[1:]
436
  probs = probs[1:]
437
+
438
+ # Update history
439
+ history.insert(0, ('USER', text))
440
+ history.insert(0, ('ASSISTANT', out))
441
+
442
  return out