Spaces:
Runtime error
Runtime error
Cartinoe5930
committed on
Commit
•
823c850
1
Parent(s):
4cb1f6b
Update model_inference.py
Browse files
- model_inference.py +25 -25
model_inference.py
CHANGED
@@ -47,31 +47,6 @@ def generate_question(agents, question):
|
|
47 |
|
48 |
return agent_contexts, content
|
49 |
|
50 |
-
def generate_answer(model, formatted_prompt):
|
51 |
-
API_URL = endpoint_dict[model]
|
52 |
-
headers = {"Authorization": f"Bearer {args.auth_token}"}
|
53 |
-
payload = {"inputs": formatted_prompt}
|
54 |
-
try:
|
55 |
-
resp = requests.post(API_URL, json=payload, headers=headers)
|
56 |
-
response = resp.json()
|
57 |
-
except:
|
58 |
-
print("retrying due to an error......")
|
59 |
-
time.sleep(5)
|
60 |
-
return generate_answer(API_URL, headers, payload)
|
61 |
-
|
62 |
-
return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
|
63 |
-
|
64 |
-
def prompt_formatting(model, instruction, cot):
|
65 |
-
if model == "alpaca" or model == "orca":
|
66 |
-
prompt = prompt_dict[model]["prompt_no_input"]
|
67 |
-
else:
|
68 |
-
prompt = prompt_dict[model]["prompt"]
|
69 |
-
|
70 |
-
if cot:
|
71 |
-
instruction += "Let's think step by step."
|
72 |
-
|
73 |
-
return {"model": model, "content": prompt.format(instruction)}
|
74 |
-
|
75 |
def Inference(model_list, question, API_KEY, auth_token, round, cot):
|
76 |
if len(model_list) != 3:
|
77 |
raise ValueError("Please choose just '3' models! Neither more nor less!")
|
@@ -80,6 +55,31 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
|
|
80 |
|
81 |
prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
agents = len(model_list)
|
84 |
rounds = round
|
85 |
|
|
|
47 |
|
48 |
return agent_contexts, content
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def Inference(model_list, question, API_KEY, auth_token, round, cot):
|
51 |
if len(model_list) != 3:
|
52 |
raise ValueError("Please choose just '3' models! Neither more nor less!")
|
|
|
55 |
|
56 |
prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
|
57 |
|
58 |
+
def generate_answer(model, formatted_prompt):
|
59 |
+
API_URL = endpoint_dict[model]["API_URL"]
|
60 |
+
headers = endpoint_dict[model]["headers"]
|
61 |
+
payload = {"inputs": formatted_prompt}
|
62 |
+
try:
|
63 |
+
resp = requests.post(API_URL, json=payload, headers=headers)
|
64 |
+
response = resp.json()
|
65 |
+
except:
|
66 |
+
print("retrying due to an error......")
|
67 |
+
time.sleep(5)
|
68 |
+
return generate_answer(API_URL, headers, payload)
|
69 |
+
|
70 |
+
return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
|
71 |
+
|
72 |
+
def prompt_formatting(model, instruction, cot):
|
73 |
+
if model == "alpaca" or model == "orca":
|
74 |
+
prompt = prompt_dict[model]["prompt_no_input"]
|
75 |
+
else:
|
76 |
+
prompt = prompt_dict[model]["prompt"]
|
77 |
+
|
78 |
+
if cot:
|
79 |
+
instruction += "Let's think step by step."
|
80 |
+
|
81 |
+
return {"model": model, "content": prompt.format(instruction)}
|
82 |
+
|
83 |
agents = len(model_list)
|
84 |
rounds = round
|
85 |
|