Commit c640769 by hzxwonder · "update"
Parent: 30becf4

This commit threads a new `api_type` argument (default `"azure"`) through the shared LLM helpers in deciders/utils.py and updates their call sites in the distillers to pass it along.

Files changed:
- deciders/utils.py (+26, -25)
- distillers/guider.py (+4, -4)
- distillers/self_reflection.py (+1, -1)
- distillers/traj_prompt_summarizer.py (+1, -1)
deciders/utils.py CHANGED

```diff
@@ -25,23 +25,23 @@ def run_chain(chain, *args, **kwargs):
     return chain.run(*args, **kwargs)
 
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
-    response = openai.Completion.create(
-        engine=engine,
-        prompt=prompt,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        top_p=1,
-        frequency_penalty=0.0,
-        presence_penalty=0.0,
-        stop=stop_strs,
-        # request_timeout = 1
-    )
-    return response.choices[0].text
+def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
+    if api_type == "azure":
+        response = openai.Completion.create(
+            engine=engine,
+            prompt=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=1,
+            frequency_penalty=0.0,
+            presence_penalty=0.0,
+            stop=stop_strs,
+            # request_timeout = 1
+        )
+        return response.choices[0].text
 
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
+def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
     assert model != "text-davinci-003"
     messages = [
         {
@@ -49,14 +49,15 @@ def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-tur
             "content": prompt
         }
     ]
-    response = openai.ChatCompletion.create(
-        model=model,
-        engine=engine,
-        messages=messages,
-        max_tokens=max_tokens,
-        stop=stop_strs,
-        temperature=temperature,
-        # request_timeout = 1
-    )
-    return response.choices[0]["message"]["content"]
+    if api_type == "azure":
+        response = openai.ChatCompletion.create(
+            model=model,
+            engine=engine,
+            messages=messages,
+            max_tokens=max_tokens,
+            stop=stop_strs,
+            temperature=temperature,
+            # request_timeout = 1
+        )
+        return response.choices[0]["message"]["content"]
```
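Note that, as committed, both helpers return a value only when `api_type == "azure"`; any other value falls off the end of the function and yields `None`. Below is a minimal sketch of what an explicit non-Azure branch could look like, assuming the same pre-1.0 `openai` SDK these helpers already use. The `else` branch and the `"openai"` value are illustrative assumptions, not part of this commit:

```python
import openai
from typing import List, Optional

def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
             engine: str = "gpt-35-turbo", temperature: float = 0.0,
             max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
    messages = [{"role": "user", "content": prompt}]
    if api_type == "azure":
        # Azure routes requests by deployment name, passed as `engine`.
        response = openai.ChatCompletion.create(
            model=model, engine=engine, messages=messages,
            max_tokens=max_tokens, stop=stop_strs, temperature=temperature,
        )
    else:
        # The public OpenAI endpoint takes only `model`; `engine` is Azure-specific.
        response = openai.ChatCompletion.create(
            model=model, messages=messages,
            max_tokens=max_tokens, stop=stop_strs, temperature=temperature,
        )
    return response.choices[0]["message"]["content"]
```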
distillers/guider.py CHANGED

```diff
@@ -67,7 +67,7 @@ class Guidance_Generator():
 
     def generate_summary(self, traj, post_memory):
         query = self._generate_summary_query(traj, post_memory)
-        summary = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        summary = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The summary prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The summary response is: {summary}.')
         return summary
@@ -78,7 +78,7 @@ class Guidance_Generator():
         for i, m in enumerate(post_memory):
             query += f'Episode #{i}: {m}\n'
         query += '\n Identify and summarize the key information that can be exploited to improve performance of the player.'
-        insight = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        insight = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         # import pdb;pdb.set_trace()
         logger.info(f'[Reflexion Memory]The insight prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The insight response is: {insight}.')
@@ -96,7 +96,7 @@ class Guidance_Generator():
         query += f"\n The main aim for you is to help the human player win the game in the last episode. He has only {max(max_num_trials-len(post_memory), 1)} episodes left to try.You can give suggestions before each episode. Then what is your suggestion for his next episode? Please provide simple, concise answers suitable for a six-year-old child, focusing on the following in item list format: 1. What game-relevant knowledge is critical to determine the optimal policy. Notice that the knowledge should be obtainable by interacting with the environment and helpful for the decisions.\n 2. How should the player conduct exploration in the next episode to acquire this information?\n3. How can the player exploit the information obtained to achieve higher performance in subsequent episodes?\n 4. How should exploration and exploitation be balanced to improve performance in the next episode?\n"
 
         # TODO: consider the inconsistency between past suggestion and past memory.
-        suggestion = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        suggestion = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         self.suggestion = suggestion
         logger.info(f'[Reflexion Memory]The suggestion prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The suggestion response is: {suggestion}.')
@@ -107,7 +107,7 @@ class Guidance_Generator():
             reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_summary_query(traj, memory)
-        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
```
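Every updated call site now reads `self.args.api_type`, so whichever argument parser builds `args` must define that attribute. A hypothetical sketch of the matching flag follows; the flag name, default, and choices are assumptions inferred from the call sites, not shown in this commit:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gpt_version", default="gpt-35-turbo",
                    help="Model/deployment name forwarded as both `model` and `engine`.")
parser.add_argument("--api_type", default="azure", choices=["azure", "openai"],
                    help="Backend that the helpers in deciders/utils.py should target.")
args = parser.parse_args()
```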
distillers/self_reflection.py CHANGED

```diff
@@ -53,7 +53,7 @@ class RefletionGenerator():
             reflection_query = self._generate_reflection_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_reflection_query(traj, memory)
-        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
```
|
distillers/traj_prompt_summarizer.py
CHANGED
@@ -54,7 +54,7 @@ class TrajPromptSummarizer():
|
|
54 |
reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
|
55 |
else:
|
56 |
reflection_query = self._generate_summary_query(traj, memory)
|
57 |
-
reflection = get_completion(reflection_query, engine=self.args.gpt_version)
|
58 |
logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
|
59 |
logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
|
60 |
return reflection
|
|
|
54 |
reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
|
55 |
else:
|
56 |
reflection_query = self._generate_summary_query(traj, memory)
|
57 |
+
reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
|
58 |
logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
|
59 |
logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
|
60 |
return reflection
|
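A side note on the `# @retry(...)` comments that survive in deciders/utils.py: they point at tenacity-style retries that are currently disabled. If rate limits become an issue, one way to restore them without touching the helpers is a thin wrapper. This is a sketch under the assumption that `tenacity` is installed; the backoff parameters simply mirror the commented-out decorator:

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential

from deciders.utils import get_completion

# Mirrors the commented-out decorator: exponential backoff between 1s and 60s,
# giving up after six attempts.
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def get_completion_with_retry(*args, **kwargs):
    return get_completion(*args, **kwargs)
```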