Upload predict.py
Browse files
- predict.py +7 -1
predict.py
CHANGED
@@ -333,6 +333,7 @@ def predict(model, text, tokenizer=None,
|
|
333 |
sft=True, convo_template = "",
|
334 |
device = "cuda",
|
335 |
model_name="AquilaChat2-7B",
|
|
|
336 |
**kwargs):
|
337 |
|
338 |
vocab = tokenizer.get_vocab()
|
@@ -352,7 +353,7 @@ def predict(model, text, tokenizer=None,
|
|
352 |
topk = 1
|
353 |
temperature = 1.0
|
354 |
if sft:
|
355 |
-
tokens = covert_prompt_to_input_ids_with_history(text, history=
|
356 |
tokens = torch.tensor(tokens)[None,].to(device)
|
357 |
else :
|
358 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|
@@ -433,4 +434,9 @@ def predict(model, text, tokenizer=None,
|
|
433 |
|
434 |
convert_tokens = convert_tokens[1:]
|
435 |
probs = probs[1:]
|
|
|
|
|
|
|
|
|
|
|
436 |
return out
|
|
|
333 |
sft=True, convo_template = "",
|
334 |
device = "cuda",
|
335 |
model_name="AquilaChat2-7B",
|
336 |
+
history=[],
|
337 |
**kwargs):
|
338 |
|
339 |
vocab = tokenizer.get_vocab()
|
|
|
353 |
topk = 1
|
354 |
temperature = 1.0
|
355 |
if sft:
|
356 |
+
tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
|
357 |
tokens = torch.tensor(tokens)[None,].to(device)
|
358 |
else :
|
359 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|
|
|
434 |
|
435 |
convert_tokens = convert_tokens[1:]
|
436 |
probs = probs[1:]
|
437 |
+
|
438 |
+
# Update history
|
439 |
+
history.insert(0, ('USER', text))
|
440 |
+
history.insert(0, ('ASSISTANT', out))
|
441 |
+
|
442 |
return out
|