Spaces:
Paused
Paused
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
author: caishaofei <caishaofei@stu.pku.edu.cn>
|
3 |
+
date: 2024-09-20 20:10:44
|
4 |
+
Copyright © Team CraftJarvis All rights reserved
|
5 |
+
'''
|
6 |
+
import re
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
import argparse
|
12 |
+
import requests
|
13 |
+
import gradio as gr
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
from io import BytesIO
|
17 |
+
from PIL import Image, ImageDraw
|
18 |
+
from rocket.arm.sessions import Session, Pointer
|
19 |
+
|
20 |
+
# Fixed palette of 15 colour triples (each channel 0-255).
COLORS = [
    # saturated primaries
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    # saturated secondaries
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    # white, black, mid grey
    (255, 255, 255),
    (0, 0, 0),
    (128, 128, 128),
    # half-intensity variants
    (128, 0, 0),
    (128, 128, 0),
    (0, 128, 0),
    (128, 0, 128),
    (0, 128, 128),
    (0, 0, 128),
]
|
27 |
+
|
28 |
+
# Integer id associated with each interaction (segment) type name.
# NOTE(review): "Use" and "Interact" share id 3 and no name maps to 1 --
# presumably intentional aliasing; confirm against the model's type ids.
SEGMENT_MAPPING = dict(
    Hunt=0,
    Use=3,
    Mine=2,
    Interact=3,
    Craft=4,
    Switch=5,
    Approach=6,
)
|
31 |
+
|
32 |
+
# Template action in which nothing is pressed: every button is 0 and the
# camera delta is a zero 2-vector. Key order is significant because the
# "Action" dropdown is populated from it.
NOOP_ACTION = {
    **dict.fromkeys(
        (
            "back",
            "drop",
            "forward",
            *(f"hotbar.{slot}" for slot in range(1, 10)),
            "inventory",
            "jump",
            "left",
            "right",
            "sneak",
            "sprint",
        ),
        0,
    ),
    "camera": np.array([0, 0]),
    "attack": 0,
    "use": 0,
}
|
55 |
+
|
56 |
+
def reset_fn(env_name, session):
    """Reset the session to the chosen environment.

    Args:
        env_name: Environment identifier forwarded to ``session.reset``.
        session: Session object holding the environment.

    Returns:
        Tuple ``(image, session)`` where ``image`` is whatever
        ``session.reset`` returned.
    """
    return session.reset(env_name), session
|
59 |
+
|
60 |
+
def step_fn(act_key, session):
    """Step the environment once with a single manually selected action.

    Starts from a copy of ``NOOP_ACTION`` and sets the selected entry to 1.

    Args:
        act_key: Name of the action to activate. ``None`` (dropdown left
            empty) and the string ``"null"`` both mean "send a no-op".
        session: Session object whose ``step(action)`` advances the
            environment.

    Returns:
        Tuple ``(image, session)`` where ``image`` is the value returned by
        ``session.step``.
    """
    action = NOOP_ACTION.copy()
    # dict.copy() is shallow: give this action its own camera array so that
    # in-place edits downstream cannot corrupt the module-level template.
    action["camera"] = NOOP_ACTION["camera"].copy()
    # An unselected dropdown may deliver None; without this guard a stray
    # ``None`` key would be inserted into the action dict.
    if act_key is not None and act_key != "null":
        # NOTE(review): choosing "camera" replaces the 2-vector with the
        # scalar 1 -- confirm whether the environment accepts that.
        action[act_key] = 1
    image = session.step(action)
    return image, session
|
66 |
+
|
67 |
+
def loop_step_fn(steps, session):
    """Generator that steps the session ``steps`` times.

    Each iteration calls ``session.step()`` with no arguments and yields a
    UI update.

    Args:
        steps: Number of iterations to run.
        session: Session object exposing ``step()`` and ``num_steps``.

    Yields:
        Tuple ``(image, session.num_steps, status, session)`` after every
        step, where ``status`` is a progress message.
    """
    for done in range(1, steps + 1):
        frame = session.step()
        progress_text = f"Running Agent `Rocket` steps: {done}/{steps}. "
        yield frame, session.num_steps, progress_text, session
|
72 |
+
|
73 |
+
def clear_memory_fn(session):
    """Clear the agent's memory and reset the displayed memory counter.

    The current image is read before ``session.clear_agent_memory()`` is
    called.

    Args:
        session: Session object exposing ``current_image`` and
            ``clear_agent_memory()``.

    Returns:
        Tuple ``(image, "0", session)``; the string ``"0"`` is the new
        memory-length text.
    """
    frame = session.current_image
    session.clear_agent_memory()
    return (frame, "0", session)
|
77 |
+
|
78 |
+
def get_points_with_draw(image, label, session, evt: gr.SelectData):
    """Record a clicked point on the session and draw it on the image.

    Args:
        image: Image array currently shown; the marker is drawn onto it in
            place.
        label: Radio value; ``'Add Points'`` records label 1, any other
            value records label 0.
        session: Session object whose ``points`` and ``points_label`` lists
            are appended to.
        evt: Gradio select event; ``evt.index[0]`` / ``evt.index[1]`` are
            used as the x / y of the click.

    Returns:
        Tuple ``(image, session)``.
    """
    clicked, flags = session.points, session.points_label
    col, row = evt.index[0], evt.index[1]
    is_positive = label == 'Add Points'
    # (0, 255, 0) for added points, (255, 0, 0) otherwise.
    marker_color = (0, 255, 0) if is_positive else (255, 0, 0)
    marker_radius = 5
    clicked.append([col, row])
    flags.append(1 if is_positive else 0)
    # Thickness -1 draws a filled circle.
    cv2.circle(image, (col, row), marker_radius, marker_color, -1)
    return image, session
|
87 |
+
|
88 |
+
def clear_points_fn(session):
    """Drop all prompt points stored on the session.

    Args:
        session: Session object exposing ``clear_points()`` and
            ``current_image``.

    Returns:
        Tuple ``(session.current_image, session)``, read after the points
        were cleared.
    """
    session.clear_points()
    frame = session.current_image
    return frame, session
|
91 |
+
|
92 |
+
def segment_fn(session):
    """Run segmentation from the stored prompt points, if there are any.

    Args:
        session: Session object exposing ``points``, ``segment()``,
            ``apply_mask()`` and ``current_image``.

    Returns:
        Tuple ``(image, session)``. With no points the image is
        ``session.current_image`` and nothing is segmented; otherwise it is
        the result of ``session.apply_mask()`` after ``session.segment()``.
    """
    if len(session.points) != 0:
        session.segment()
        return session.apply_mask(), session
    return session.current_image, session
|
98 |
+
|
99 |
+
def clear_segment_fn(session):
    """Discard the current object mask and switch tracking off.

    Args:
        session: Session object exposing ``clear_obj_mask()``,
            ``tracking_flag`` and ``current_image``.

    Returns:
        Tuple ``(session.current_image, False, session)``; the ``False`` is
        the new value for the tracking checkbox.
    """
    session.clear_obj_mask()
    session.tracking_flag = False
    frame = session.current_image
    return frame, False, session
|
103 |
+
|
104 |
+
def set_tracking_mode(tracking_flag, session):
    """Store the tracking checkbox value on the session.

    Args:
        tracking_flag: New value for ``session.tracking_flag``.
        session: Session object to update.

    Returns:
        The same ``session`` object.
    """
    setattr(session, "tracking_flag", tracking_flag)
    return session
|
107 |
+
|
108 |
+
def set_segment_type(segment_type, session):
    """Store the selected segment type on the session.

    Args:
        segment_type: New value for ``session.segment_type``.
        session: Session object to update.

    Returns:
        The same ``session`` object.
    """
    setattr(session, "segment_type", segment_type)
    return session
|
111 |
+
|
112 |
+
def play_fn(session):
    """Advance the session by one step using ``session.step()`` with no action.

    Args:
        session: Session object exposing ``step()``.

    Returns:
        Tuple ``(image, session)`` where ``image`` is the value returned by
        ``session.step()``.
    """
    return session.step(), session
|
115 |
+
|
116 |
+
# Read-only textbox showing the agent's memory length. It is created at module
# level, before the layout exists, and placed later via `memory_length.render()`
# inside `draw_gradio_components`, so event handlers can list it as an output.
memory_length = gr.Textbox(value="0", interactive=False, show_label=False)
|
117 |
+
|
118 |
+
def make_video_fn(session, make_video, save_video, progress=gr.Progress()):
    """Encode the session's recorded frames into ``rocket.mp4``.

    Frames are taken from ``session.image_history``, converted from RGB to
    BGR for OpenCV, and written at 20 fps using the ``mp4v`` codec. The
    frame size is taken from the first frame. After a successful encode the
    history is emptied.

    Args:
        session: Session object exposing ``image_history``.
        make_video: Current value of the "Make Video" button; returned
            unchanged when there is nothing to encode.
        save_video: Current value of the download button; returned
            unchanged when there is nothing to encode.
        progress: Gradio progress tracker used to wrap the frame loop.

    Returns:
        Tuple ``(session, make_video_update, save_video_update)``. When
        frames were encoded, the "Make Video" button is hidden and the
        download button is shown pointing at the written file.
    """
    images = session.image_history
    if len(images) == 0:
        return session, make_video, save_video
    # NOTE(review): the output name is fixed, so concurrent sessions would
    # overwrite each other's file -- confirm whether that matters here.
    filepath = "rocket.mp4"
    h, w = images[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h))
    try:
        for image in progress.tqdm(images):
            video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the container, even if a frame fails to convert,
        # so the writer handle is not leaked.
        video.release()
    session.image_history = []
    return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True)
|
132 |
+
|
133 |
+
def save_video_fn(session, make_video, save_video):
    """Restore the video buttons after the download button was clicked.

    Args:
        session: Session object, passed through untouched.
        make_video: Current "Make Video" button value (unused).
        save_video: Current download button value (unused).

    Returns:
        Tuple ``(session, make_video_update, save_video_update)`` that shows
        the "Make Video" button again and hides the download button.
    """
    show_make_button = gr.Button("Make Video", visible=True)
    hide_download_button = gr.DownloadButton("Download!", visible=False)
    return session, show_make_button, hide_download_button
|
135 |
+
|
136 |
+
def choose_sam_fn(sam_choice, session):
    """Record the chosen SAM checkpoint on the session and load it.

    Args:
        sam_choice: Value assigned to ``session.sam_choice``.
        session: Session object exposing ``load_sam()``.

    Returns:
        The same ``session`` object, after ``load_sam()`` has been called.
    """
    setattr(session, "sam_choice", sam_choice)
    session.load_sam()
    return session
|
140 |
+
|
141 |
+
def molmo_fn(molmo_text, molmo_session, rocket_session, display_image):
    """Ask the pointing model for points and add them as positive prompts.

    A copy of ``rocket_session.current_image`` and the prompt are passed to
    ``molmo_session.gen_point``. Each returned point is truncated to
    integers, appended to ``rocket_session.points`` with label 1, and drawn
    on ``display_image`` as a filled circle.

    Args:
        molmo_text: Prompt text for the pointing model.
        molmo_session: Object exposing ``gen_point(image=..., prompt=...)``
            and ``molmo_result``.
        rocket_session: Session whose ``points`` / ``points_label`` lists
            are extended in place.
        display_image: Image array currently shown; drawn onto in place.

    Returns:
        Tuple ``(molmo_result, display_image)`` where ``molmo_result`` is
        read from ``molmo_session.molmo_result`` after the query.
    """
    snapshot = rocket_session.current_image.copy()
    located = molmo_session.gen_point(image=snapshot, prompt=molmo_text)
    reply = molmo_session.molmo_result
    marker_radius, marker_color = 5, (0, 255, 0)
    for raw_x, raw_y in located:
        px, py = int(raw_x), int(raw_y)
        rocket_session.points.append([px, py])
        rocket_session.points_label.append(1)
        cv2.circle(display_image, (px, py), marker_radius, marker_color, -1)
    return reply, display_image
|
152 |
+
|
153 |
+
def extract_points(data, width=640, height=360):
    """Parse point markup into pixel coordinates.

    Recognises attribute pairs such as ``x="12.5" y="40"`` (as in a
    ``<point>`` tag) and indexed pairs such as ``x3="12.5" y3="40"`` (as in
    a ``<points>`` tag), with any number of index digits. Matched values are
    treated as percentages of the frame size.

    Args:
        data: Text to scan for coordinate attributes.
        width: Pixel width that 100% on the x axis maps to.
        height: Pixel height that 100% on the y axis maps to.

    Returns:
        List of ``(x, y)`` float tuples in pixel units, in order of
        appearance. Empty when nothing matches.
    """
    # An optionally signed integer or decimal; the sign applies to both forms.
    number = r'[-+]?(?:\d*\.\d+|\d+)'
    # ``\d*`` (rather than ``\d?``) so indices of 10 and above are matched too.
    pattern = rf'x\d*="({number})" y\d*="({number})"'
    matches = re.findall(pattern, data)
    # Convert the captured percentage strings to pixel coordinates.
    return [(float(x) / 100 * width, float(y) / 100 * height) for x, y in matches]
|
160 |
+
|
161 |
+
def draw_gradio_components(args):
    """Build the Gradio UI, wire up all event handlers and launch the server.

    Args:
        args: Parsed command-line namespace. ``args.sam_path`` and
            ``args.port`` are required. ``args.molmo_id`` and
            ``args.molmo_url`` are used for the pointing model when present;
            otherwise the previously hard-coded values are used.

    Returns:
        None. Ends by calling ``demo.launch``.
    """
    # NOTE(review): container nesting below was reconstructed from statement
    # order and blank lines; verify the layout against the running app.
    with gr.Blocks() as demo:

        gr.Markdown(
            """
            # Welcome to Explore ROCKET-1 in Minecraft!!
            ## Please follow next steps to interact with the agent:
            1. Reset the environment by selecting an environment name.
            2. Select a SAM2 checkpoint to load.
            3. Use your mouse to add or remove points on the image.
            4. Select the segment type you want to perform.
            5. Enable `tracking` mode if you want to track objects while stepping actions.
            6. Click `New Segment` to segment the image based on the points you added.
            7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps.
            ## Hints:
            1. You can use the `Make Video` button to generate a video of the agent's actions.
            2. You can use the `Clear Memory` button to clear the ROCKET-1's memory.
            3. You can use the `Clear Segment` button to clear SAM's memory.
            4. You can use the `Manually Step` button to manually step the agent.
            """
        )

        rocket_session = gr.State(Session(
            sam_path=args.sam_path,
        ))
        # Honour --molmo-id / --molmo-url; fall back to the former hard-coded
        # values for callers whose namespace lacks these attributes.
        molmo_session = gr.State(Pointer(
            model_id=getattr(args, "molmo_id", "molmo-72b-0924"),
            model_url=getattr(args, "molmo_url", "http://172.17.30.127:8000/v1"),
        ))
        with gr.Row():

            with gr.Column(scale=2):
                # Black 640x360 placeholder shown as the initial image.
                start_image = np.zeros((360, 640, 3), dtype=np.uint8)

                with gr.Group():
                    display_image = gr.Image(
                        value=np.array(start_image),
                        interactive=False,
                        show_label=False,
                        label="Real-time Environment Observation",
                        streaming=True
                    )
                    display_status = gr.Textbox("Status Bar", interactive=False, show_label=False)

            with gr.Column(scale=1):

                sam_choice = gr.Radio(
                    choices=["large", "base", "small", "tiny"],
                    value="base",
                    label="Select SAM2 checkpoint",
                )
                sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False)

                with gr.Group():
                    add_or_remove = gr.Radio(
                        choices=["Add Points", "Remove Areas"],
                        value="Add Points",
                        label="Use you mouse to add or remove points",
                    )
                    clear_points_btn = gr.Button("Clear Points")
                    clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)

                with gr.Group():
                    segment_type = gr.Radio(
                        choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"],
                        value="Approach",
                        label="What do you want with this segment?",
                    )
                    track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions")
                    track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False)
                    with gr.Group(), gr.Row():
                        new_segment_btn = gr.Button("New Segment")
                        clear_segment_btn = gr.Button("Clear Segment")
                    new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
                    clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True)

                # Clicking the image adds a prompt point.
                display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session])
                segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False)

        with gr.Row():
            with gr.Group():
                # Every YAML config whose file name does not contain "base";
                # sorted so the dropdown order is stable across runs.
                env_list = sorted(
                    f"rocket/{x.stem}"
                    for x in Path("../env_configs/rocket").glob("*.yaml")
                    if 'base' not in x.name
                )
                env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name")
                reset_btn = gr.Button("Reset Environment")
                reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True)

            with gr.Group():
                action_list = [x for x in NOOP_ACTION.keys()]
                act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action")
                step_btn = gr.Button("Manually Step")
                step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False)

            with gr.Group():
                steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False)
                play_btn = gr.Button("Call Rocket")
                play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False)

            with gr.Group():
                # The textbox is defined at module level; place it here.
                memory_length.render()
                clear_states_btn = gr.Button("Clear Memory")
                clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False)

            make_video_btn = gr.Button("Make Video")
            save_video_btn = gr.DownloadButton("Download!!", visible=False)
            make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
            save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
        with gr.Row():
            with gr.Group():
                molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200)
                molmo_btn = gr.Button("Generate")
                output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200)
                molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image], outputs=[output_text, display_image], show_progress=False)

    demo.queue()
    demo.launch(share=False, server_port=args.port)
|
278 |
+
|
279 |
+
if __name__ == '__main__':
    # Script entry point: parse the command-line options and start the UI.
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every supported option.
    option_table = (
        ("--port", int, 7860),
        ("--sam-path", str, "/app/ROCKET-1/rocket/realtime_sam/checkpoints"),
        ("--molmo-id", str, "molmo-72b-0924"),
        ("--molmo-url", str, "http://127.0.0.1:8000/v1"),
    )
    for flag, caster, fallback in option_table:
        parser.add_argument(flag, type=caster, default=fallback)
    args = parser.parse_args()
    draw_gradio_components(args)
|