Spaces:
Runtime error
Runtime error
Merge pull request #15 from mail-ecnu/feat-proxy
Browse files
- deciders/act.py +13 -10
- deciders/utils.py +1 -0
- envs/mujoco/ant_translator.py +1 -2
deciders/act.py
CHANGED
@@ -11,7 +11,7 @@ from memory.env_history import EnvironmentHistory
|
|
11 |
import tiktoken
|
12 |
import json
|
13 |
import re
|
14 |
-
from .utils import run_chain
|
15 |
from gym.spaces import Discrete
|
16 |
|
17 |
class RandomAct():
|
@@ -156,12 +156,15 @@ class NaiveAct(gpt):
|
|
156 |
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
|
157 |
prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
|
158 |
print(f"prompt is {prompt}")
|
159 |
-
res =
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
165 |
return prompt, res
|
166 |
|
167 |
def _add_history_before_action(self, game_description, goal_description, state_description):
|
@@ -210,8 +213,8 @@ class NaiveAct(gpt):
|
|
210 |
my_mem += f"{self.env_history.get_histories(self.mem_num)}"
|
211 |
|
212 |
|
213 |
-
prompt,
|
214 |
-
action_str =
|
215 |
print(f'my anwser is {action_str}')
|
216 |
action = self.parser.parse(response).action
|
217 |
self._add_history_after_action(action)
|
@@ -219,7 +222,7 @@ class NaiveAct(gpt):
|
|
219 |
self.logger.info(f'The optimal action is: {action}.')
|
220 |
if env_info.get('history'):
|
221 |
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
222 |
-
return action, prompt,
|
223 |
|
224 |
def _read_mem(self, ):
|
225 |
memory = self.memory
|
|
|
11 |
import tiktoken
|
12 |
import json
|
13 |
import re
|
14 |
+
from .utils import run_chain, get_completion, get_chat
|
15 |
from gym.spaces import Discrete
|
16 |
|
17 |
class RandomAct():
|
|
|
156 |
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
|
157 |
prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
|
158 |
print(f"prompt is {prompt}")
|
159 |
+
# res = get_chat(prompt, self.args.api_type, self.args.gpt_version, self.temperature, self.max_tokens)
|
160 |
+
res = get_chat(prompt, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version, temperature=self.temperature, max_tokens=self.max_tokens)
|
161 |
+
# openai.ChatCompletion.create(
|
162 |
+
# engine=self.args.gpt_version,
|
163 |
+
# # model=self.args.gpt_version,
|
164 |
+
# prompt=prompt,
|
165 |
+
# temperature=self.temperature,
|
166 |
+
# max_tokens=self.max_tokens,
|
167 |
+
# )
|
168 |
return prompt, res
|
169 |
|
170 |
def _add_history_before_action(self, game_description, goal_description, state_description):
|
|
|
213 |
my_mem += f"{self.env_history.get_histories(self.mem_num)}"
|
214 |
|
215 |
|
216 |
+
prompt, response = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
|
217 |
+
action_str = response
|
218 |
print(f'my anwser is {action_str}')
|
219 |
action = self.parser.parse(response).action
|
220 |
self._add_history_after_action(action)
|
|
|
222 |
self.logger.info(f'The optimal action is: {action}.')
|
223 |
if env_info.get('history'):
|
224 |
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
225 |
+
return action, prompt, response, 0, 0
|
226 |
|
227 |
def _read_mem(self, ):
|
228 |
memory = self.memory
|
deciders/utils.py
CHANGED
@@ -54,6 +54,7 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
|
|
54 |
temperature=temperature,
|
55 |
# request_timeout = 1
|
56 |
)
|
|
|
57 |
return response.choices[0]["message"]["content"]
|
58 |
|
59 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
|
|
54 |
temperature=temperature,
|
55 |
# request_timeout = 1
|
56 |
)
|
57 |
+
import pdb; pdb.set_trace()
|
58 |
return response.choices[0]["message"]["content"]
|
59 |
|
60 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
envs/mujoco/ant_translator.py
CHANGED
@@ -94,8 +94,7 @@ class GameDescriber:
|
|
94 |
|
95 |
def describe_action(self):
|
96 |
return (
|
97 |
-
"Your next move: \n Please
|
98 |
-
"representing the torques to be applied at the eight hinge joints of the ant."
|
99 |
)
|
100 |
|
101 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
|
|
94 |
|
95 |
def describe_action(self):
|
96 |
return (
|
97 |
+
"Your next move: \n Please choose your action which applies torques at the eight hinge joints of the ant. It be a list of eight numerical values and each value is within the range of [-1,1]."
|
|
|
98 |
)
|
99 |
|
100 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|