phython96 commited on
Commit
3c3d9d2
·
verified ·
1 Parent(s): ad3c511

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ author: caishaofei <caishaofei@stu.pku.edu.cn>
3
+ date: 2024-09-20 20:10:44
4
+ Copyright © Team CraftJarvis All rights reserved
5
+ '''
6
+ import re
7
+ import os
8
+ import cv2
9
+ import time
10
+ from pathlib import Path
11
+ import argparse
12
+ import requests
13
+ import gradio as gr
14
+ import torch
15
+ import numpy as np
16
+ from io import BytesIO
17
+ from PIL import Image, ImageDraw
18
+ from rocket.arm.sessions import Session, Pointer
19
+
20
+ COLORS = [
21
+ (255, 0, 0), (0, 255, 0), (0, 0, 255),
22
+ (255, 255, 0), (255, 0, 255), (0, 255, 255),
23
+ (255, 255, 255), (0, 0, 0), (128, 128, 128),
24
+ (128, 0, 0), (128, 128, 0), (0, 128, 0),
25
+ (128, 0, 128), (0, 128, 128), (0, 0, 128),
26
+ ]
27
+
28
+ SEGMENT_MAPPING = {
29
+ "Hunt": 0, "Use": 3, "Mine": 2, "Interact": 3, "Craft": 4, "Switch": 5, "Approach": 6
30
+ }
31
+
32
+ NOOP_ACTION = {
33
+ "back": 0,
34
+ "drop": 0,
35
+ "forward": 0,
36
+ "hotbar.1": 0,
37
+ "hotbar.2": 0,
38
+ "hotbar.3": 0,
39
+ "hotbar.4": 0,
40
+ "hotbar.5": 0,
41
+ "hotbar.6": 0,
42
+ "hotbar.7": 0,
43
+ "hotbar.8": 0,
44
+ "hotbar.9": 0,
45
+ "inventory": 0,
46
+ "jump": 0,
47
+ "left": 0,
48
+ "right": 0,
49
+ "sneak": 0,
50
+ "sprint": 0,
51
+ "camera": np.array([0, 0]),
52
+ "attack": 0,
53
+ "use": 0,
54
+ }
55
+
56
+ def reset_fn(env_name, session):
57
+ image = session.reset(env_name)
58
+ return image, session
59
+
60
+ def step_fn(act_key, session):
61
+ action = NOOP_ACTION.copy()
62
+ if act_key != "null":
63
+ action[act_key] = 1
64
+ image = session.step(action)
65
+ return image, session
66
+
67
+ def loop_step_fn(steps, session):
68
+ for i in range(steps):
69
+ image = session.step()
70
+ status = f"Running Agent `Rocket` steps: {i+1}/{steps}. "
71
+ yield image, session.num_steps, status, session
72
+
73
+ def clear_memory_fn(session):
74
+ image = session.current_image
75
+ session.clear_agent_memory()
76
+ return image, "0", session
77
+
78
+ def get_points_with_draw(image, label, session, evt: gr.SelectData):
79
+ points = session.points
80
+ point_label = session.points_label
81
+ x, y = evt.index[0], evt.index[1]
82
+ point_radius, point_color = 5, (0, 255, 0) if label == 'Add Points' else (255, 0, 0)
83
+ points.append([x, y])
84
+ point_label.append(1 if label == 'Add Points' else 0)
85
+ cv2.circle(image, (x, y), point_radius, point_color, -1)
86
+ return image, session
87
+
88
+ def clear_points_fn(session):
89
+ session.clear_points()
90
+ return session.current_image, session
91
+
92
+ def segment_fn(session):
93
+ if len(session.points) == 0:
94
+ return session.current_image, session
95
+ session.segment()
96
+ image = session.apply_mask()
97
+ return image, session
98
+
99
+ def clear_segment_fn(session):
100
+ session.clear_obj_mask()
101
+ session.tracking_flag = False
102
+ return session.current_image, False, session
103
+
104
+ def set_tracking_mode(tracking_flag, session):
105
+ session.tracking_flag = tracking_flag
106
+ return session
107
+
108
+ def set_segment_type(segment_type, session):
109
+ session.segment_type = segment_type
110
+ return session
111
+
112
+ def play_fn(session):
113
+ image = session.step()
114
+ return image, session
115
+
116
+ memory_length = gr.Textbox(value="0", interactive=False, show_label=False)
117
+
118
+ def make_video_fn(session, make_video, save_video, progress=gr.Progress()):
119
+ images = session.image_history
120
+ if len(images) == 0:
121
+ return session, make_video, save_video
122
+ filepath = "rocket.mp4"
123
+ h, w = images[0].shape[:2]
124
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
125
+ video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h))
126
+ for image in progress.tqdm(images):
127
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
128
+ video.write(image)
129
+ video.release()
130
+ session.image_history = []
131
+ return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True)
132
+
133
+ def save_video_fn(session, make_video, save_video):
134
+ return session, gr.Button("Make Video", visible=True), gr.DownloadButton("Download!", visible=False)
135
+
136
+ def choose_sam_fn(sam_choice, session):
137
+ session.sam_choice = sam_choice
138
+ session.load_sam()
139
+ return session
140
+
141
+ def molmo_fn(molmo_text, molmo_session, rocket_session, display_image):
142
+ image = rocket_session.current_image.copy()
143
+ points = molmo_session.gen_point(image=image, prompt=molmo_text)
144
+ molmo_result = molmo_session.molmo_result
145
+ for x, y in points:
146
+ x, y = int(x), int(y)
147
+ point_radius, point_color = 5, (0, 255, 0)
148
+ rocket_session.points.append([x, y])
149
+ rocket_session.points_label.append(1)
150
+ cv2.circle(display_image, (x, y), point_radius, point_color, -1)
151
+ return molmo_result, display_image
152
+
153
+ def extract_points(data):
154
+ # 匹配 x 和 y 坐标的值,支持 <points> 和 <point> 标签
155
+ pattern = r'x\d?="([-+]?\d*\.\d+|\d+)" y\d?="([-+]?\d*\.\d+|\d+)"'
156
+ points = re.findall(pattern, data)
157
+ # 将提取到的坐标转换为浮点数
158
+ points = [(float(x)/100*640, float(y)/100*360) for x, y in points]
159
+ return points
160
+
161
+ def draw_gradio_components(args):
162
+
163
+ with gr.Blocks() as demo:
164
+
165
+ gr.Markdown(
166
+ """
167
+ # Welcome to Explore ROCKET-1 in Minecraft!!
168
+ ## Please follow next steps to interact with the agent:
169
+ 1. Reset the environment by selecting an environment name.
170
+ 2. Select a SAM2 checkpoint to load.
171
+ 3. Use your mouse to add or remove points on the image.
172
+ 4. Select the segment type you want to perform.
173
+ 5. Enable `tracking` mode if you want to track objects while stepping actions.
174
+ 6. Click `New Segment` to segment the image based on the points you added.
175
+ 7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps.
176
+ ## Hints:
177
+ 1. You can use the `Make Video` button to generate a video of the agent's actions.
178
+ 2. You can use the `Clear Memory` button to clear the ROCKET-1's memory.
179
+ 3. You can use the `Clear Segment` button to clear SAM's memory.
180
+ 4. You can use the `Manually Step` button to manually step the agent.
181
+ """
182
+ )
183
+
184
+ rocket_session = gr.State(Session(
185
+ sam_path=args.sam_path,
186
+ ))
187
+ molmo_session = gr.State(Pointer(
188
+ model_id="molmo-72b-0924",
189
+ model_url="http://172.17.30.127:8000/v1",
190
+ ))
191
+ with gr.Row():
192
+
193
+ with gr.Column(scale=2):
194
+ # start_image = Image.open("start.png").resize((640, 360))
195
+ start_image = np.zeros((360, 640, 3), dtype=np.uint8)
196
+
197
+ with gr.Group():
198
+ display_image = gr.Image(
199
+ value=np.array(start_image),
200
+ interactive=False,
201
+ show_label=False,
202
+ label="Real-time Environment Observation",
203
+ streaming=True
204
+ )
205
+ display_status = gr.Textbox("Status Bar", interactive=False, show_label=False)
206
+
207
+ with gr.Column(scale=1):
208
+
209
+ sam_choice = gr.Radio(
210
+ choices=["large", "base", "small", "tiny"],
211
+ value="base",
212
+ label="Select SAM2 checkpoint",
213
+ )
214
+ sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False)
215
+
216
+ with gr.Group():
217
+ add_or_remove = gr.Radio(
218
+ choices=["Add Points", "Remove Areas"],
219
+ value="Add Points",
220
+ label="Use you mouse to add or remove points",
221
+ )
222
+ clear_points_btn = gr.Button("Clear Points")
223
+ clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
224
+
225
+ with gr.Group():
226
+ segment_type = gr.Radio(
227
+ choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"],
228
+ value="Approach",
229
+ label="What do you want with this segment?",
230
+ )
231
+ track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions")
232
+ track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False)
233
+ with gr.Group(), gr.Row():
234
+ new_segment_btn = gr.Button("New Segment")
235
+ clear_segment_btn = gr.Button("Clear Segment")
236
+ new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
237
+ clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True)
238
+
239
+ display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session])
240
+ segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False)
241
+
242
+ with gr.Row():
243
+ with gr.Group():
244
+ env_list = [f"rocket/{x.stem}" for x in Path("../env_configs/rocket").glob("*.yaml") if 'base' not in x.name != 'base']
245
+ env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name")
246
+ reset_btn = gr.Button("Reset Environment")
247
+ reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True)
248
+
249
+ with gr.Group():
250
+ action_list = [x for x in NOOP_ACTION.keys()]
251
+ act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action")
252
+ step_btn = gr.Button("Manually Step")
253
+ step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False)
254
+
255
+ with gr.Group():
256
+ steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False)
257
+ play_btn = gr.Button("Call Rocket")
258
+ play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False)
259
+
260
+ with gr.Group():
261
+ memory_length.render()
262
+ clear_states_btn = gr.Button("Clear Memory")
263
+ clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False)
264
+
265
+ make_video_btn = gr.Button("Make Video")
266
+ save_video_btn = gr.DownloadButton("Download!!", visible=False)
267
+ make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
268
+ save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
269
+ with gr.Row():
270
+ with gr.Group():
271
+ molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200)
272
+ molmo_btn = gr.Button("Generate")
273
+ output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200)
274
+ molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image],outputs=[output_text, display_image],show_progress=False)
275
+
276
+ demo.queue()
277
+ demo.launch(share=False,server_port=args.port)
278
+
279
+ if __name__ == '__main__':
280
+ parser = argparse.ArgumentParser()
281
+ parser.add_argument("--port", type=int, default=7860)
282
+ parser.add_argument("--sam-path", type=str, default="/app/ROCKET-1/rocket/realtime_sam/checkpoints")
283
+ parser.add_argument("--molmo-id", type=str, default="molmo-72b-0924")
284
+ parser.add_argument("--molmo-url", type=str, default="http://127.0.0.1:8000/v1")
285
+ args = parser.parse_args()
286
+ draw_gradio_components(args)