Jarvis committed on
Commit
841d805
2 Parent(s): 59cf660 3ffd15b

Merge pull request #15 from mail-ecnu/feat-proxy

Browse files
deciders/act.py CHANGED
@@ -11,7 +11,7 @@ from memory.env_history import EnvironmentHistory
11
  import tiktoken
12
  import json
13
  import re
14
- from .utils import run_chain
15
  from gym.spaces import Discrete
16
 
17
  class RandomAct():
@@ -156,12 +156,15 @@ class NaiveAct(gpt):
156
  prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
157
  prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
158
  print(f"prompt is {prompt}")
159
- res = openai.Completion.create(
160
- engine=self.args.gpt_version,
161
- prompt=prompt,
162
- temperature=self.temperature,
163
- max_tokens=self.max_tokens,
164
- )
 
 
 
165
  return prompt, res
166
 
167
  def _add_history_before_action(self, game_description, goal_description, state_description):
@@ -210,8 +213,8 @@ class NaiveAct(gpt):
210
  my_mem += f"{self.env_history.get_histories(self.mem_num)}"
211
 
212
 
213
- prompt, res = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
214
- action_str = res.choices[0].text.strip()
215
  print(f'my anwser is {action_str}')
216
  action = self.parser.parse(response).action
217
  self._add_history_after_action(action)
@@ -219,7 +222,7 @@ class NaiveAct(gpt):
219
  self.logger.info(f'The optimal action is: {action}.')
220
  if env_info.get('history'):
221
  self.logger.info(f'History: {history_to_str(env_info["history"])}')
222
- return action, prompt, res, 0, 0
223
 
224
  def _read_mem(self, ):
225
  memory = self.memory
 
11
  import tiktoken
12
  import json
13
  import re
14
+ from .utils import run_chain, get_completion, get_chat
15
  from gym.spaces import Discrete
16
 
17
  class RandomAct():
 
156
  prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
157
  prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
158
  print(f"prompt is {prompt}")
159
+ # res = get_chat(prompt, self.args.api_type, self.args.gpt_version, self.temperature, self.max_tokens)
160
+ res = get_chat(prompt, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version, temperature=self.temperature, max_tokens=self.max_tokens)
161
+ # openai.ChatCompletion.create(
162
+ # engine=self.args.gpt_version,
163
+ # # model=self.args.gpt_version,
164
+ # prompt=prompt,
165
+ # temperature=self.temperature,
166
+ # max_tokens=self.max_tokens,
167
+ # )
168
  return prompt, res
169
 
170
  def _add_history_before_action(self, game_description, goal_description, state_description):
 
213
  my_mem += f"{self.env_history.get_histories(self.mem_num)}"
214
 
215
 
216
+ prompt, response = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
217
+ action_str = response
218
  print(f'my anwser is {action_str}')
219
  action = self.parser.parse(response).action
220
  self._add_history_after_action(action)
 
222
  self.logger.info(f'The optimal action is: {action}.')
223
  if env_info.get('history'):
224
  self.logger.info(f'History: {history_to_str(env_info["history"])}')
225
+ return action, prompt, response, 0, 0
226
 
227
  def _read_mem(self, ):
228
  memory = self.memory
deciders/utils.py CHANGED
@@ -54,6 +54,7 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
54
  temperature=temperature,
55
  # request_timeout = 1
56
  )
 
57
  return response.choices[0]["message"]["content"]
58
 
59
  # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
 
54
  temperature=temperature,
55
  # request_timeout = 1
56
  )
57
+ import pdb; pdb.set_trace()
58
  return response.choices[0]["message"]["content"]
59
 
60
  # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
envs/mujoco/ant_translator.py CHANGED
@@ -94,8 +94,7 @@ class GameDescriber:
94
 
95
  def describe_action(self):
96
  return (
97
- "Your next move: \n Please provide a list of eight numerical values, each within the range of [-1,1], "
98
- "representing the torques to be applied at the eight hinge joints of the ant."
99
  )
100
 
101
  class BasicStateSequenceTranslator(BasicLevelTranslator):
 
94
 
95
  def describe_action(self):
96
  return (
97
+ "Your next move: \n Please choose your action which applies torques at the eight hinge joints of the ant. It be a list of eight numerical values and each value is within the range of [-1,1]."
 
98
  )
99
 
100
  class BasicStateSequenceTranslator(BasicLevelTranslator):