ewanlee committed
Commit eaa7556
Parent: 841d805

Atari visualization with Gradio

Files changed (4):
  1. .gitignore +4 -1
  2. deciders/parser.py +3 -3
  3. environment.yaml +2 -1
  4. gradio_reflexion.py +312 -0
.gitignore CHANGED
@@ -186,4 +186,7 @@ main_jarvis.sh
  test*.py
  *.zip
  test_
- *.ipynb
+ *.ipynb
+
+ # gradio
+ flagged
deciders/parser.py CHANGED
@@ -7,10 +7,10 @@ class DisActionModel(BaseModel):
      @classmethod
      def create_validator(cls, max_action):
          @validator('action', allow_reuse=True)
-         def action_is_valid(cls, field):
-             if field not in range(1, max_action + 1):
+         def action_is_valid(cls, info):
+             if info not in range(1, max_action + 1):
                  raise ValueError(f"Action is not valid ([1, {max_action}])!")
-             return field
+             return info
          return action_is_valid

  # Generate classes dynamically
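The rename from field to info is cosmetic here: with a Pydantic v1 @validator, the decorated function still receives the candidate value of 'action', so the range check is unchanged. A minimal sketch of how a dynamically created model built on this kind of validator is expected to behave (illustrative only; make_action_model and ThreeActionModel are made-up names, not the repository's code):

from pydantic import BaseModel, validator, ValidationError

def make_action_model(max_action: int):
    # Stand-in for the dynamic class generation: build a model whose
    # 'action' field must fall in [1, max_action].
    class ActionModel(BaseModel):
        action: int

        @validator('action', allow_reuse=True)
        def action_is_valid(cls, info):
            # 'info' receives the candidate value of 'action' (Pydantic v1 passes the value itself).
            if info not in range(1, max_action + 1):
                raise ValueError(f"Action is not valid ([1, {max_action}])!")
            return info

    return ActionModel

ThreeActionModel = make_action_model(3)
print(ThreeActionModel(action=2))   # action=2
try:
    ThreeActionModel(action=7)      # out of range -> ValidationError
except ValidationError as err:
    print(err)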
environment.yaml CHANGED
@@ -186,4 +186,5 @@ dependencies:
  - win32-setctime==1.1.0
  - yarl==1.9.2
  - zipp==3.15.0
- - git+ssh://git@github.com/hyyh28/atari-representation-learning.git
+ - git+ssh://git@github.com/hyyh28/atari-representation-learning.git
+ - gradio==4.13.0
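The gradio==4.13.0 pin is what gradio_reflexion.py below relies on. An optional sanity check after rebuilding the conda environment (assuming it was updated from this environment.yaml) could be:

# Verify the newly pinned Gradio is importable and at the expected version.
import gradio as gr

assert gr.__version__.startswith("4.13"), f"unexpected gradio version: {gr.__version__}"
print("gradio", gr.__version__, "is ready")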
gradio_reflexion.py ADDED
@@ -0,0 +1,312 @@
import envs
import deciders
import distillers
import prompts as task_prompts
import datetime
import time
from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
import gym
import pandas as pd
import random
from loguru import logger
from argparse import Namespace
import gradio as gr


def set_seed(seed):
    random.seed(seed)


def main_progress(env_name, decider, prompt_level, num_trails, seed):
    """Run LLM-decider episodes and stream (frame, state, prompt, response, action) to the Gradio UI."""
    init_summarizer = env_name.split("-")[0] + '_init_translator'
    curr_summarizer = env_name.split("-")[0] + '_basic_translator'
    args = Namespace(
        env_name=env_name,
        init_summarizer=init_summarizer,
        curr_summarizer=curr_summarizer,
        decider=decider,
        prompt_level=prompt_level,
        num_trails=num_trails,
        seed=seed,
        future_summarizer=None,
        env="base_env",
        gpt_version="gpt-3.5-turbo",
        render="rgb_array",
        max_episode_len=200,
        max_query_tokens=5000,
        max_tokens=2000,
        distiller="traj_distiller",
        prompt_path=None,
        use_short_mem=1,
        short_mem_num=10,
        is_only_local_obs=1,
        api_type="azure",
    )

    if args.api_type != "azure" and args.api_type != "openai":
        raise ValueError(f"The {args.api_type} is not supported, please use 'azure' or 'openai'!")

    # Note: with "azure" the model name is "gpt-35-turbo", while with "openai" it is "gpt-3.5-turbo".
    if args.api_type == "azure":
        if args.gpt_version == "gpt-3.5-turbo":
            args.gpt_version = 'gpt-35-turbo'
    elif args.api_type == "openai":
        if args.gpt_version == "gpt-35-turbo":
            args.gpt_version = 'gpt-3.5-turbo'

    # Get the specified translator, environment, and ChatGPT model
    env_class = envs.REGISTRY[args.env]
    init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer], args)
    curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])

    if args.future_summarizer:
        future_summarizer = FutureSummarizer(
            envs.REGISTRY[args.future_summarizer],
            envs.REGISTRY["cart_policies"],
            future_horizon=args.future_horizon,
        )
    else:
        future_summarizer = None

    decider_class = deciders.REGISTRY[args.decider]
    distiller_class = distillers.REGISTRY[args.distiller]
    sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
    if args.prompt_level == 5:
        prompts_class = task_prompts.REGISTRY[(args.env_name, args.decider)]()
    else:
        prompts_class = task_prompts.REGISTRY[(args.decider)]()
    translator = Translator(
        init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
    )
    environment = env_class(
        gym.make(args.env_name, render_mode=args.render), translator
    )

    logfile = (
        f"llm.log/output-{args.env_name}-{args.decider}-{args.gpt_version}-l{args.prompt_level}"
        f"-{datetime.datetime.now().timestamp()}.log"
    )

    logfile_reflexion = (
        f"llm.log/memory-{args.env_name}-{args.decider}-{args.gpt_version}-l{args.prompt_level}"
        f"-{datetime.datetime.now().timestamp()}.log"
    )
    my_distiller = distiller_class(logfile=logfile_reflexion, args=args)

    args.game_description = environment.game_description
    args.goal_description = environment.goal_description
    args.action_description = environment.action_description
    args.action_desc_dict = environment.action_desc_dict
    args.reward_desc_dict = environment.reward_desc_dict

    logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])

    decider = decider_class(environment.env.action_space, args, prompts_class, my_distiller, temperature=0.0, logger=logger, max_tokens=args.max_tokens)

    # Evaluate the translator
    utilities = []
    df = pd.read_csv('record_reflexion.csv', sep=',')
    filtered_df = df[(df['env'] == args.env_name) & (df['decider'] == 'expert') & (df['level'] == 1)]
    expert_score = filtered_df['avg_score'].item()
    seeds = [i for i in range(1000)]
    # prompt_file = "prompt.txt"
    # f = open(prompt_file, "w+")
    num_trails = args.num_trails
    if "Blackjack" not in args.env_name:
        curriculums = 1
    else:
        curriculums = 20
    for curriculum in range(curriculums):
        for trail in range(num_trails):
            if "Blackjack" in args.env_name:
                seed = seeds[curriculum * curriculums + num_trails - trail - 1]
            else:
                seed = args.seed

            # Reset the environment
            if "Blackjack" not in args.env_name:
                set_seed(args.seed)
                seed = args.seed
                state_description, env_info = environment.reset(seed=args.seed)
            else:
                set_seed(seed)
                state_description, env_info = environment.reset(seed=seed)
            game_description = environment.get_game_description()
            goal_description = environment.get_goal_description()
            action_description = environment.get_action_description()

            # Initialize the statistics
            frames = []
            utility = 0
            current_total_tokens = 0
            current_total_cost = 0
            start_time = datetime.datetime.now()
            # Run the game for a maximum number of steps
            for round in range(args.max_episode_len):
                # Keep asking ChatGPT for an action until it provides a valid one
                error_flag = True
                retry_num = 1
                for error_i in range(retry_num):
                    try:
                        action, prompt, response, tokens, cost = decider.act(
                            state_description,
                            action_description,
                            env_info,
                            game_description,
                            goal_description,
                            logfile
                        )

                        state_description, reward, termination, truncation, env_info = environment.step_llm(
                            action
                        )
                        if "Cliff" in args.env_name or "Frozen" in args.env_name:
                            decider.env_history.add('reward', env_info['potential_state'] + environment.reward_desc_dict[reward])
                        else:
                            decider.env_history.add('reward', f"The player gets rewards {reward}.")

                        utility += reward

                        # Update the statistics
                        current_total_tokens += tokens
                        current_total_cost += cost
                        error_flag = False
                        break
                    except Exception as e:
                        print(e)
                        if error_i < retry_num - 1:
                            if "Cliff" in args.env_name or "Frozen" in args.env_name:
                                decider.env_history.remove_invalid_state()
                            decider.env_history.remove_invalid_state()
                            if logger:
                                logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
                        continue
                if error_flag:
                    # Fall back to the decider's default action when no valid action was obtained
                    action = decider.default_action
                    state_description, reward, termination, truncation, env_info = environment.step_llm(
                        action
                    )

                    decider.env_history.add('action', decider.default_action)

                    if "Cliff" in args.env_name or "Frozen" in args.env_name:
                        # decider.env_history.add('reward', reward)
                        decider.env_history.add('reward', env_info['potential_state'] + environment.reward_desc_dict[reward])
                    utility += reward

                    logger.info(f"Seed: {seed}")
                    logger.info(f'The optimal action is: {decider.default_action}.')
                    logger.info(f"Now it is round {round}.")
                else:
                    current_total_tokens += tokens
                    current_total_cost += cost
                    logger.info(f"Seed: {seed}")
                    logger.info(f"current_total_tokens: {current_total_tokens}")
                    logger.info(f"current_total_cost: {current_total_cost}")
                    logger.info(f"Now it is round {round}.")

                # Stream the current frame and texts to the Gradio outputs
                yield environment.render(), state_description, prompt, response, action

                if termination or truncation:
                    if logger:
                        logger.info(f"Terminated!")
                    break
                time.sleep(10)
            decider.env_history.add(
                'terminate_state', environment.get_terminate_state(round + 1, args.max_episode_len))
            decider.env_history.add("cummulative_reward", str(utility))
            # Record the final reward
            if logger:
                logger.info(f"Cumulative reward: {utility}.")
            end_time = datetime.datetime.now()
            time_diff = end_time - start_time
            logger.info(f"Time consumed: {time_diff.total_seconds()} s")

            utilities.append(utility)
            # TODO: set env success utility threshold
            if trail < num_trails - 1:
                if args.decider in ['reflexion']:
                    if utility < expert_score:
                        decider.update_mem()
                else:
                    decider.update_mem()
            decider.clear_mem()
    return utilities


# def pause():
#     for i in range(31415926):
#         time.sleep(0.1)
#         yield i

if __name__ == "__main__":
    custom_css = """
    #render {
        flex-grow: 1;
    }
    #input_text .tabs {
        display: flex;
        flex-direction: column;
        flex-grow: 1;
    }
    #input_text .tabitem[style="display: block;"] {
        flex-grow: 1;
        display: flex !important;
    }
    #input_text .gap {
        flex-grow: 1;
    }
    #input_text .form {
        flex-grow: 1 !important;
    }
    #input_text .form > :last-child {
        flex-grow: 1;
    }
    """

    with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
        with gr.Row():
            env_name = gr.Dropdown(
                ["RepresentedBoxing-v0",
                 "RepresentedPong-v0",
                 "RepresentedMsPacman-v0",
                 "RepresentedMontezumaRevenge-v0"],
                label="Environment Name")
            decider = gr.Dropdown(
                ["naive_actor",
                 "cot_actor",
                 "spp_actor",
                 "reflexion_actor"],
                label="Decider")
            prompt_level = gr.Dropdown([1, 2, 3, 4, 5], label="Prompt Level")
        with gr.Row():
            num_trails = gr.Slider(1, 100, 1, label="Number of Trials", scale=2)
            seed = gr.Slider(1, 1000, 1, label="Seed", scale=2)
            run = gr.Button("Run", scale=1)
            # pause_ = gr.Button("Pause")
            # resume = gr.Button("Resume")
            stop = gr.Button("Stop", scale=1)
        with gr.Row():
            with gr.Column():
                render = gr.Image(label="render", elem_id="render")
            with gr.Column(elem_id="input_text"):
                state = gr.Textbox(label="translated state")
                prompt = gr.Textbox(label="prompt", max_lines=100)
        with gr.Row():
            response = gr.Textbox(label="response")
            action = gr.Textbox(label="parsed action")
        run_event = run.click(
            fn=main_progress,
            inputs=[env_name, decider, prompt_level, num_trails, seed],
            outputs=[render, state, prompt, response, action])
        stop.click(fn=None, inputs=None, outputs=None, cancels=[run_event])
        # pause_event = pause_.click(fn=pause, inputs=None, outputs=None)
        # resume.click(fn=None, inputs=None, outputs=None, cancels=[pause_event])

    demo.launch(server_name="0.0.0.0", server_port=7860)
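For context on how this UI stays live during an episode: main_progress is a generator, Gradio 4.x streams each yielded tuple into the five components listed in outputs, and cancels=[run_event] lets the Stop button interrupt the stream. A stripped-down sketch of that pattern, with a dummy generator standing in for main_progress (fake_progress is illustrative, not project code):

import time
import gradio as gr

def fake_progress(steps):
    # Stand-in for main_progress: yield one tuple per step; Gradio pushes each
    # tuple into the components listed in outputs as it arrives.
    for i in range(int(steps)):
        yield f"state at step {i}", f"action {i % 3 + 1}"
        time.sleep(0.5)

with gr.Blocks() as demo:
    steps = gr.Slider(1, 20, value=5, label="Steps")
    run = gr.Button("Run")
    stop = gr.Button("Stop")
    state = gr.Textbox(label="translated state")
    action = gr.Textbox(label="parsed action")
    run_event = run.click(fn=fake_progress, inputs=[steps], outputs=[state, action])
    # Cancelling the run event stops the generator mid-stream, mirroring the Stop button above.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[run_event])

demo.launch()

Running the real script (python gradio_reflexion.py) serves the full demo on port 7860, as set in demo.launch above.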