hzxwonder committed
Commit c640769 · 1 Parent(s): 30becf4
deciders/utils.py CHANGED
@@ -25,23 +25,23 @@ def run_chain(chain, *args, **kwargs):
     return chain.run(*args, **kwargs)
 
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
-    response = openai.Completion.create(
-        model=engine,
-        engine=engine,
-        prompt=prompt,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        top_p=1,
-        frequency_penalty=0.0,
-        presence_penalty=0.0,
-        stop=stop_strs,
-        # request_timeout = 1
-    )
-    return response.choices[0].text
+def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
+    if api_type == "azure":
+        response = openai.Completion.create(
+            engine=engine,
+            prompt=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=1,
+            frequency_penalty=0.0,
+            presence_penalty=0.0,
+            stop=stop_strs,
+            # request_timeout = 1
+        )
+    return response.choices[0].text
 
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
+def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
     assert model != "text-davinci-003"
     messages = [
         {
@@ -49,14 +49,15 @@ def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-tur
             "content": prompt
         }
     ]
-    response = openai.ChatCompletion.create(
-        model=model,
-        engine=engine,
-        messages=messages,
-        max_tokens=max_tokens,
-        stop=stop_strs,
-        temperature=temperature,
-        # request_timeout = 1
-    )
-    return response.choices[0]["message"]["content"]
+    if api_type == "azure":
+        response = openai.ChatCompletion.create(
+            model=model,
+            engine=engine,
+            messages=messages,
+            max_tokens=max_tokens,
+            stop=stop_strs,
+            temperature=temperature,
+            # request_timeout = 1
+        )
+    return response.choices[0]["message"]["content"]
 
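Note: as committed, both helpers assign response only inside the if api_type == "azure": branch, so any other api_type reaches the return with response unbound and raises UnboundLocalError. A minimal sketch of how a non-Azure branch could close that gap (the elif/else arms below are an assumption, not part of this commit):

    import openai

    # Sketch only: the non-Azure arms are assumptions, not part of this commit.
    def get_chat(prompt, api_type="azure", model="gpt-35-turbo", engine="gpt-35-turbo",
                 temperature=0.0, max_tokens=256, stop_strs=None):
        messages = [{"role": "user", "content": prompt}]
        if api_type == "azure":
            # Azure addresses deployments via the engine keyword.
            response = openai.ChatCompletion.create(
                model=model, engine=engine, messages=messages,
                max_tokens=max_tokens, stop=stop_strs, temperature=temperature)
        elif api_type == "openai":
            # The public OpenAI API takes model only; engine is Azure-specific.
            response = openai.ChatCompletion.create(
                model=model, messages=messages,
                max_tokens=max_tokens, stop=stop_strs, temperature=temperature)
        else:
            raise ValueError(f"Unsupported api_type: {api_type!r}")
        return response.choices[0]["message"]["content"]

The commented-out retry decorator above each helper matches tenacity's API; re-enabling it only needs from tenacity import retry, wait_random_exponential, stop_after_attempt.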
distillers/guider.py CHANGED
@@ -67,7 +67,7 @@ class Guidance_Generator():
 
     def generate_summary(self, traj, post_memory):
         query = self._generate_summary_query(traj, post_memory)
-        summary = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        summary = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The summary prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The summary response is: {summary}.')
         return summary
@@ -78,7 +78,7 @@ class Guidance_Generator():
         for i, m in enumerate(post_memory):
             query += f'Episode #{i}: {m}\n'
         query += '\n Identify and summarize the key information that can be exploited to improve performance of the player.'
-        insight = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        insight = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         # import pdb;pdb.set_trace()
         logger.info(f'[Reflexion Memory]The insight prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The insight response is: {insight}.')
@@ -96,7 +96,7 @@ class Guidance_Generator():
         query += f"\n The main aim for you is to help the human player win the game in the last episode. He has only {max(max_num_trials-len(post_memory), 1)} episodes left to try.You can give suggestions before each episode. Then what is your suggestion for his next episode? Please provide simple, concise answers suitable for a six-year-old child, focusing on the following in item list format: 1. What game-relevant knowledge is critical to determine the optimal policy. Notice that the knowledge should be obtainable by interacting with the environment and helpful for the decisions.\n 2. How should the player conduct exploration in the next episode to acquire this information?\n3. How can the player exploit the information obtained to achieve higher performance in subsequent episodes?\n 4. How should exploration and exploitation be balanced to improve performance in the next episode?\n"
 
         # TODO: consider the inconsistency between past suggestion and past memory.
-        suggestion = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
+        suggestion = get_chat(query, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version)
         self.suggestion = suggestion
         logger.info(f'[Reflexion Memory]The suggestion prompt is: {query}.')
         logger.info(f'[Reflexion Memory]The suggestion response is: {suggestion}.')
@@ -107,7 +107,7 @@ class Guidance_Generator():
             reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_summary_query(traj, memory)
-        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
distillers/self_reflection.py CHANGED
@@ -53,7 +53,7 @@ class RefletionGenerator():
             reflection_query = self._generate_reflection_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_reflection_query(traj, memory)
-        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
distillers/traj_prompt_summarizer.py CHANGED
@@ -54,7 +54,7 @@ class TrajPromptSummarizer():
             reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
         else:
             reflection_query = self._generate_summary_query(traj, memory)
-        reflection = get_completion(reflection_query, engine=self.args.gpt_version)
+        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
         logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
         logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
         return reflection
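Note: every updated call site now reads self.args.api_type, so the experiment's argument parser must define that flag. A minimal sketch of the assumed flag (the flag name and choices are inferred from these call sites; neither is shown in this commit):

    import argparse

    parser = argparse.ArgumentParser()
    # Assumed flag: "azure" mirrors the new default in deciders/utils.py;
    # the "openai" alternative is a guess, since only the Azure path is implemented.
    parser.add_argument("--api_type", type=str, default="azure", choices=["azure", "openai"])
    parser.add_argument("--gpt_version", type=str, default="gpt-35-turbo")
    args = parser.parse_args()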