Cartinoe5930 committed on
Commit
823c850
1 Parent(s): 4cb1f6b

Update model_inference.py

Browse files
Files changed (1) hide show
  1. model_inference.py +25 -25
model_inference.py CHANGED
@@ -47,31 +47,6 @@ def generate_question(agents, question):
47
 
48
  return agent_contexts, content
49
 
50
def generate_answer(model, formatted_prompt):
    """POST `formatted_prompt` to the inference endpoint registered for `model`.

    Parameters
    ----------
    model : str
        Key into the module-level `endpoint_dict` / `prompt_dict` tables.
    formatted_prompt : str
        Fully rendered prompt text sent as the `inputs` payload.

    Returns
    -------
    dict
        {"model": model, "content": <text after the model's response_split marker>}.

    Retries indefinitely with a 5-second backoff when the request or JSON
    decoding fails.
    """
    API_URL = endpoint_dict[model]  # assumes endpoint_dict maps model -> URL string — TODO confirm
    headers = {"Authorization": f"Bearer {args.auth_token}"}
    payload = {"inputs": formatted_prompt}
    try:
        resp = requests.post(API_URL, json=payload, headers=headers)
        response = resp.json()
    except Exception:  # narrowed from bare `except:` so KeyboardInterrupt/SystemExit still escape
        print("retrying due to an error......")
        time.sleep(5)
        # BUG FIX: the retry previously called generate_answer(API_URL, headers, payload),
        # passing three positional args to a two-parameter function — a TypeError on
        # every retry. Retry with the original arguments instead.
        return generate_answer(model, formatted_prompt)

    return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
63
-
64
def prompt_formatting(model, instruction, cot):
    """Render `instruction` into the prompt template registered for `model`.

    Returns {"model": model, "content": <formatted prompt string>}.
    """
    # alpaca and orca use the no-input template variant; all other models
    # fall back to the generic one.
    key = "prompt_no_input" if model in ("alpaca", "orca") else "prompt"
    template = prompt_dict[model][key]

    if cot:
        # NOTE(review): suffix is appended with no separator before "Let's" —
        # confirm that is intended.
        instruction = instruction + "Let's think step by step."

    return {"model": model, "content": template.format(instruction)}
74
-
75
  def Inference(model_list, question, API_KEY, auth_token, round, cot):
76
  if len(model_list) != 3:
77
  raise ValueError("Please choose just '3' models! Neither more nor less!")
@@ -80,6 +55,31 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
80
 
81
  prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  agents = len(model_list)
84
  rounds = round
85
 
 
47
 
48
  return agent_contexts, content
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def Inference(model_list, question, API_KEY, auth_token, round, cot):
51
  if len(model_list) != 3:
52
  raise ValueError("Please choose just '3' models! Neither more nor less!")
 
55
 
56
  prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
57
 
58
def generate_answer(model, formatted_prompt):
    """POST `formatted_prompt` to `model`'s inference endpoint and parse the reply.

    Reads the enclosing scope's `endpoint_dict` (per-model "API_URL" and
    "headers" entries) and `prompt_dict` (per-model "response_split" marker).

    Parameters
    ----------
    model : str
        Key into `endpoint_dict` / `prompt_dict`.
    formatted_prompt : str
        Fully rendered prompt text sent as the `inputs` payload.

    Returns
    -------
    dict
        {"model": model, "content": <text after the response_split marker>}.

    Retries indefinitely with a 5-second backoff on request/decoding errors.
    """
    API_URL = endpoint_dict[model]["API_URL"]
    headers = endpoint_dict[model]["headers"]
    payload = {"inputs": formatted_prompt}
    try:
        resp = requests.post(API_URL, json=payload, headers=headers)
        response = resp.json()
    except Exception:  # narrowed from bare `except:` so KeyboardInterrupt/SystemExit still escape
        print("retrying due to an error......")
        time.sleep(5)
        # BUG FIX: previously retried as generate_answer(API_URL, headers, payload) —
        # three positional args to a two-parameter function, raising TypeError on
        # every retry. Retry with the original arguments instead.
        return generate_answer(model, formatted_prompt)

    # Keep only the text after the model-specific response delimiter.
    return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
71
+
72
def prompt_formatting(model, instruction, cot):
    """Build the final prompt for `model` from `instruction`.

    Uses the enclosing scope's `prompt_dict`; returns
    {"model": model, "content": <formatted prompt string>}.
    """
    # Select the template variant: alpaca/orca ship a dedicated
    # "prompt_no_input" template, everything else uses "prompt".
    if model in ("alpaca", "orca"):
        template = prompt_dict[model]["prompt_no_input"]
    else:
        template = prompt_dict[model]["prompt"]

    if cot:
        # NOTE(review): chain-of-thought cue is concatenated without a
        # separating space/newline — verify this matches the template's needs.
        instruction = instruction + "Let's think step by step."

    return {"model": model, "content": template.format(instruction)}
82
+
83
  agents = len(model_list)
84
  rounds = round
85