Upload predict.py with huggingface_hub
Browse files- predict.py +33 -5
predict.py
CHANGED
@@ -185,6 +185,8 @@ def get_conversation_template(model_path: str) -> Conversation:
|
|
185 |
"""Get the default conversation template."""
|
186 |
if "aquila-v1" in model_path:
|
187 |
return get_conv_template("aquila-v1")
|
|
|
|
|
188 |
elif "aquila-chat" in model_path:
|
189 |
return get_conv_template("aquila-chat")
|
190 |
elif "aquila-legacy" in model_path:
|
@@ -252,6 +254,21 @@ register_conv_template(
|
|
252 |
)
|
253 |
)
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
if __name__ == "__main__":
|
257 |
print("aquila template:")
|
@@ -294,6 +311,17 @@ if __name__ == "__main__":
|
|
294 |
|
295 |
print("\n")
|
296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
def set_random_seed(seed):
|
298 |
"""Set random seed for reproducability."""
|
299 |
if seed is not None and seed > 0:
|
@@ -330,9 +358,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
|
|
330 |
return example
|
331 |
|
332 |
def predict(model, text, tokenizer=None,
|
333 |
-
max_gen_len=200, top_p=0.
|
334 |
-
seed=
|
335 |
-
temperature=0
|
336 |
sft=True, convo_template = "",
|
337 |
device = "cuda",
|
338 |
model_name="AquilaChat2-7B",
|
@@ -346,8 +374,8 @@ def predict(model, text, tokenizer=None,
|
|
346 |
|
347 |
template_map = {"AquilaChat2-7B": "aquila-v1",
|
348 |
"AquilaChat2-34B": "aquila-legacy",
|
349 |
-
"AquilaChat2-7B-16K": "aquila",
|
350 |
"AquilaChat2-70B-Expr": "aquila-v2",
|
|
|
351 |
"AquilaChat2-34B-16K": "aquila"}
|
352 |
if not convo_template:
|
353 |
convo_template=template_map.get(model_name, "aquila-chat")
|
@@ -357,7 +385,7 @@ def predict(model, text, tokenizer=None,
|
|
357 |
topk = 1
|
358 |
temperature = 1.0
|
359 |
if sft:
|
360 |
-
tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=
|
361 |
tokens = torch.tensor(tokens)[None,].to(device)
|
362 |
else :
|
363 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|
|
|
185 |
"""Get the default conversation template."""
|
186 |
if "aquila-v1" in model_path:
|
187 |
return get_conv_template("aquila-v1")
|
188 |
+
elif "aquila-v2" in model_path:
|
189 |
+
return get_conv_template("aquila-v2")
|
190 |
elif "aquila-chat" in model_path:
|
191 |
return get_conv_template("aquila-chat")
|
192 |
elif "aquila-legacy" in model_path:
|
|
|
254 |
)
|
255 |
)
|
256 |
|
257 |
+
register_conv_template(
|
258 |
+
Conversation(
|
259 |
+
name="aquila-v2",
|
260 |
+
system_message="A chat between a curious human and an artificial intelligence assistant. "
|
261 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
262 |
+
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
|
263 |
+
messages=(),
|
264 |
+
offset=0,
|
265 |
+
sep_style=SeparatorStyle.NO_COLON_TWO,
|
266 |
+
sep="",
|
267 |
+
sep2="</s>",
|
268 |
+
stop_str=["</s>", "<|endoftext|>", "<|startofpiece|>", "<|endofpiece|>"],
|
269 |
+
)
|
270 |
+
)
|
271 |
+
|
272 |
|
273 |
if __name__ == "__main__":
|
274 |
print("aquila template:")
|
|
|
311 |
|
312 |
print("\n")
|
313 |
|
314 |
+
print("aquila-v2 template:")
|
315 |
+
conv = get_conv_template("aquila-v2")
|
316 |
+
conv.append_message(conv.roles[0], "Hello!")
|
317 |
+
conv.append_message(conv.roles[1], "Hi!")
|
318 |
+
conv.append_message(conv.roles[0], "How are you?")
|
319 |
+
conv.append_message(conv.roles[1], None)
|
320 |
+
print(conv.get_prompt())
|
321 |
+
|
322 |
+
print("\n")
|
323 |
+
|
324 |
+
|
325 |
def set_random_seed(seed):
|
326 |
"""Set random seed for reproducability."""
|
327 |
if seed is not None and seed > 0:
|
|
|
358 |
return example
|
359 |
|
360 |
def predict(model, text, tokenizer=None,
|
361 |
+
max_gen_len=200, top_p=0.9,
|
362 |
+
seed=123, topk=15,
|
363 |
+
temperature=1.0,
|
364 |
sft=True, convo_template = "",
|
365 |
device = "cuda",
|
366 |
model_name="AquilaChat2-7B",
|
|
|
374 |
|
375 |
template_map = {"AquilaChat2-7B": "aquila-v1",
|
376 |
"AquilaChat2-34B": "aquila-legacy",
|
|
|
377 |
"AquilaChat2-70B-Expr": "aquila-v2",
|
378 |
+
"AquilaChat2-7B-16K": "aquila",
|
379 |
"AquilaChat2-34B-16K": "aquila"}
|
380 |
if not convo_template:
|
381 |
convo_template=template_map.get(model_name, "aquila-chat")
|
|
|
385 |
topk = 1
|
386 |
temperature = 1.0
|
387 |
if sft:
|
388 |
+
tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=20480, convo_template=convo_template)
|
389 |
tokens = torch.tensor(tokens)[None,].to(device)
|
390 |
else :
|
391 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|