Update model.py
Browse files
model.py
CHANGED
@@ -68,30 +68,26 @@ def run(message: str,
|
|
68 |
max_new_tokens: int = 256,
|
69 |
temperature: float = 0.8,
|
70 |
top_p: float = 0.95,
|
71 |
-
top_k: int = 50) ->
|
72 |
prompt = get_prompt(message, chat_history, system_prompt)
|
73 |
inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
inputs,
|
81 |
-
streamer=streamer,
|
82 |
-
max_new_tokens=max_new_tokens,
|
83 |
do_sample=True,
|
84 |
top_p=top_p,
|
85 |
top_k=top_k,
|
86 |
temperature=temperature,
|
87 |
-
num_beams=1
|
88 |
)
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
yield ''.join(outputs)
|
|
|
68 |
max_new_tokens: int = 256,
|
69 |
temperature: float = 0.8,
|
70 |
top_p: float = 0.95,
|
71 |
+
top_k: int = 50) -> str:
|
72 |
prompt = get_prompt(message, chat_history, system_prompt)
|
73 |
inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
|
74 |
|
75 |
+
# Generate tokens using the model
|
76 |
+
output = model.generate(
|
77 |
+
input_ids=inputs['input_ids'],
|
78 |
+
attention_mask=inputs['attention_mask'],
|
79 |
+
max_length=max_new_tokens + inputs['input_ids'].shape[-1],
|
|
|
|
|
|
|
80 |
do_sample=True,
|
81 |
top_p=top_p,
|
82 |
top_k=top_k,
|
83 |
temperature=temperature,
|
84 |
+
num_beams=1
|
85 |
)
|
86 |
+
|
87 |
+
# Decode the output tokens back to a string
|
88 |
+
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
89 |
+
|
90 |
+
# Keep only the text before the first "instruct: " marker (the marker and everything after it are dropped)
|
91 |
+
output_text = output_text.split("instruct: ")[0]
|
92 |
+
|
93 |
+
return output_text
|
|